diff --git a/cmd/admin.go b/cmd/admin.go
new file mode 100644
index 0000000000000000000000000000000000000000..3421e5ddf4ee41724c9fe5bea1fb318d13f7becb
--- /dev/null
+++ b/cmd/admin.go
@@ -0,0 +1,100 @@
+/*
+Copyright © 2022 NAME HERE
+*/
+package cmd
+
+import (
+	"github.com/alist-org/alist/v3/internal/conf"
+	"github.com/alist-org/alist/v3/internal/op"
+	"github.com/alist-org/alist/v3/internal/setting"
+	"github.com/alist-org/alist/v3/pkg/utils"
+	"github.com/alist-org/alist/v3/pkg/utils/random"
+	"github.com/spf13/cobra"
+)
+
+// AdminCmd represents the admin command
+var AdminCmd = &cobra.Command{
+	Use:     "admin",
+	Aliases: []string{"password"},
+	Short:   "Show admin user's info and manage the admin user's password",
+	Run: func(cmd *cobra.Command, args []string) {
+		Init()
+		defer Release()
+		admin, err := op.GetAdmin()
+		if err != nil {
+			utils.Log.Errorf("failed to get admin user: %+v", err)
+		} else {
+			utils.Log.Infof("Admin user's username: %s", admin.Username)
+			utils.Log.Infof("The password is only printed at the first startup; afterwards it is stored as a hash value, which cannot be reversed")
+			utils.Log.Infof("You can reset the password to a random string by running [alist admin random]")
+			utils.Log.Infof("You can also set a new password by running [alist admin set NEW_PASSWORD]")
+		}
+	},
+}
+
+var RandomPasswordCmd = &cobra.Command{
+	Use:   "random",
+	Short: "Reset admin user's password to a random string",
+	Run: func(cmd *cobra.Command, args []string) {
+		newPwd := random.String(8)
+		setAdminPassword(newPwd)
+	},
+}
+
+var SetPasswordCmd = &cobra.Command{
+	Use:   "set",
+	Short: "Set admin user's password",
+	Run: func(cmd *cobra.Command, args []string) {
+		if len(args) == 0 {
+			utils.Log.Errorf("Please enter the new password")
+			return
+		}
+		setAdminPassword(args[0])
+	},
+}
+
+var ShowTokenCmd = &cobra.Command{
+	Use:   "token",
+	Short: "Show admin token",
+	Run: func(cmd *cobra.Command, args []string) {
+		Init()
+		defer Release()
+		token := setting.GetStr(conf.Token)
+		utils.Log.Infof("Admin token: %s", token)
+	},
+}
+
+func setAdminPassword(pwd string) {
+	Init()
+	defer Release()
+	admin, err := op.GetAdmin()
+	if err != nil {
+		utils.Log.Errorf("failed to get admin user: %+v", err)
+		return
+	}
+	admin.SetPassword(pwd)
+	if err := op.UpdateUser(admin); err != nil {
+		utils.Log.Errorf("failed to update admin user: %+v", err)
+		return
+	}
+	utils.Log.Infof("admin user has been updated:")
+	utils.Log.Infof("username: %s", admin.Username)
+	utils.Log.Infof("password: %s", pwd)
+	DelAdminCacheOnline()
+}
+
+func init() {
+	RootCmd.AddCommand(AdminCmd)
+	AdminCmd.AddCommand(RandomPasswordCmd)
+	AdminCmd.AddCommand(SetPasswordCmd)
+	AdminCmd.AddCommand(ShowTokenCmd)
+	// Here you will define your flags and configuration settings.
+
+	// Cobra supports Persistent Flags which will work for this command
+	// and all subcommands, e.g.:
+	// adminCmd.PersistentFlags().String("foo", "", "A help for foo")
+
+	// Cobra supports local flags which will only run when this command
+	// is called directly, e.g.:
+	// adminCmd.Flags().BoolP("toggle", "t", false, "Help message for toggle")
+}
diff --git a/cmd/cancel2FA.go b/cmd/cancel2FA.go
new file mode 100644
index 0000000000000000000000000000000000000000..08fafee84a8d405b2aae5745b7e8ddd201d3d005
--- /dev/null
+++ b/cmd/cancel2FA.go
@@ -0,0 +1,46 @@
+/*
+Copyright © 2022 NAME HERE
+*/
+package cmd
+
+import (
+	"github.com/alist-org/alist/v3/internal/op"
+	"github.com/alist-org/alist/v3/pkg/utils"
+	"github.com/spf13/cobra"
+)
+
+// Cancel2FACmd represents the cancel2fa command
+var Cancel2FACmd = &cobra.Command{
+	Use:   "cancel2fa",
+	Short: "Delete 2FA of admin user",
+	Run: func(cmd *cobra.Command, args []string) {
+		Init()
+		defer Release()
+		admin, err := op.GetAdmin()
+		if err != nil {
+			utils.Log.Errorf("failed to get admin user: %+v", err)
+		} else {
+			err := op.Cancel2FAByUser(admin)
+			if err != nil {
+				utils.Log.Errorf("failed to cancel 2FA: %+v", err)
+			} else {
+				utils.Log.Info("2FA canceled")
+				DelAdminCacheOnline()
+			}
+		}
+	},
+}
+
+func init() {
+	RootCmd.AddCommand(Cancel2FACmd)
+
+	// Here you will define your flags and configuration settings.
+
+	// Cobra supports Persistent Flags which will work for this command
+	// and all subcommands, e.g.:
+	// cancel2FACmd.PersistentFlags().String("foo", "", "A help for foo")
+
+	// Cobra supports local flags which will only run when this command
+	// is called directly, e.g.:
+	// cancel2FACmd.Flags().BoolP("toggle", "t", false, "Help message for toggle")
+}
diff --git a/cmd/common.go b/cmd/common.go
new file mode 100644
index 0000000000000000000000000000000000000000..b4a7081c33ff8e2071d2d820d36f3ace9e72a06a
--- /dev/null
+++ b/cmd/common.go
@@ -0,0 +1,49 @@
+package cmd
+
+import (
+	"os"
+	"path/filepath"
+	"strconv"
+
+	"github.com/alist-org/alist/v3/internal/bootstrap"
+	"github.com/alist-org/alist/v3/internal/bootstrap/data"
+	"github.com/alist-org/alist/v3/internal/db"
+	"github.com/alist-org/alist/v3/pkg/utils"
+	log "github.com/sirupsen/logrus"
+)
+
+func Init() {
+	bootstrap.InitConfig()
+	bootstrap.Log()
+	bootstrap.InitDB()
+	data.InitData()
+	bootstrap.InitIndex()
+}
+
+func Release() {
+	db.Close()
+}
+
+var pid = -1
+var pidFile string
+
+func initDaemon() {
+	ex, err := os.Executable()
+	if err != nil {
+		log.Fatal(err)
+	}
+	exPath := filepath.Dir(ex)
+	_ = os.MkdirAll(filepath.Join(exPath, "daemon"), 0700)
+	pidFile = filepath.Join(exPath, "daemon/pid")
+	if utils.Exists(pidFile) {
+		bytes, err := os.ReadFile(pidFile)
+		if err != nil {
+			log.Fatal("failed to read pid file: ", err)
+		}
+		id, err := strconv.Atoi(string(bytes))
+		if err != nil {
+			log.Fatal("failed to parse pid data: ", err)
+		}
+		pid = id
+	}
+}
diff --git a/cmd/flags/config.go b/cmd/flags/config.go
new file mode 100644
index 0000000000000000000000000000000000000000..f74e2cb4becd3df0de02022a474bff92e4d33f1f
--- /dev/null
+++ b/cmd/flags/config.go
@@ -0,0 +1,10 @@
+package flags
+
+var (
+	DataDir     string
+	Debug       bool
+	NoPrefix    bool
+	Dev         bool
+	ForceBinDir bool
+	LogStd      bool
+)
diff --git a/cmd/lang.go b/cmd/lang.go
new file mode 100644
index 0000000000000000000000000000000000000000..8d816ca2b30003ba20e1c283e7e4dfe7d762e094
--- /dev/null
+++ b/cmd/lang.go
@@ -0,0 +1,161 @@
+/*
+Package cmd
+Copyright © 2022 Noah Hsu
+*/
+package cmd
+
+import (
+	"fmt"
+	"io"
+	"os"
+	"reflect"
+	"strings"
+
+	_ "github.com/alist-org/alist/v3/drivers"
+	"github.com/alist-org/alist/v3/internal/bootstrap/data"
+	"github.com/alist-org/alist/v3/internal/conf"
+	"github.com/alist-org/alist/v3/internal/op"
+	"github.com/alist-org/alist/v3/pkg/utils"
+	log "github.com/sirupsen/logrus"
+	"github.com/spf13/cobra"
+)
+
+type KV[V any] map[string]V
+
+type Drivers KV[KV[interface{}]]
+
+func firstUpper(s string) string {
+	if s == "" {
+		return ""
+	}
+	return strings.ToUpper(s[:1]) + s[1:]
+}
+
+func convert(s string) string {
+	ss := strings.Split(s, "_")
+	ans := strings.Join(ss, " ")
+	return firstUpper(ans)
+}
+
+func writeFile(name string, data interface{}) {
+	f, err := os.Open(fmt.Sprintf("../alist-web/src/lang/en/%s.json", name))
+	if err != nil {
+		log.Errorf("failed to open %s.json: %+v", name, err)
+		return
+	}
+	defer f.Close()
+	content, err := io.ReadAll(f)
+	if err != nil {
+		log.Errorf("failed to read %s.json: %+v", name, err)
+		return
+	}
+	oldData := make(map[string]interface{})
+	newData := make(map[string]interface{})
+	err = utils.Json.Unmarshal(content, &oldData)
+	if err != nil {
+		log.Errorf("failed to unmarshal %s.json: %+v", name, err)
+		return
+	}
+	content, err = utils.Json.Marshal(data)
+	if err != nil {
+		log.Errorf("failed to marshal json: %+v", err)
+		return
+	}
+	err = utils.Json.Unmarshal(content, &newData)
+	if err != nil {
+		log.Errorf("failed to unmarshal json: %+v", err)
+		return
+	}
+	if reflect.DeepEqual(oldData, newData) {
+		log.Infof("%s.json unchanged, skipping", name)
+	} else {
+		log.Infof("%s.json changed, updating file", name)
+		//log.Infof("old: %+v\nnew:%+v", oldData, data)
+		utils.WriteJsonToFile(fmt.Sprintf("lang/%s.json", name), newData, true)
+	}
+}
+
+func generateDriversJson() {
+	drivers := make(Drivers)
+	drivers["drivers"] = make(KV[interface{}])
+	drivers["config"] = make(KV[interface{}])
+	driverInfoMap := op.GetDriverInfoMap()
+	for k, v := range driverInfoMap {
+		drivers["drivers"][k] = convert(k)
+		items := make(KV[interface{}])
+		config := map[string]string{}
+		if v.Config.Alert != "" {
+			alert := strings.SplitN(v.Config.Alert, "|", 2)
+			if len(alert) > 1 {
+				config["alert"] = alert[1]
+			}
+		}
+		drivers["config"][k] = config
+		for i := range v.Additional {
+			item := v.Additional[i]
+			items[item.Name] = convert(item.Name)
+			if item.Help != "" {
+				items[fmt.Sprintf("%s-tips", item.Name)] = item.Help
+			}
+			if item.Type == conf.TypeSelect && len(item.Options) > 0 {
+				options := make(KV[string])
+				_options := strings.Split(item.Options, ",")
+				for _, o := range _options {
+					options[o] = convert(o)
+				}
+				items[fmt.Sprintf("%ss", item.Name)] = options
+			}
+		}
+		drivers[k] = items
+	}
+	writeFile("drivers", drivers)
+}
+
+func generateSettingsJson() {
+	settings := data.InitialSettings()
+	settingsLang := make(KV[any])
+	for _, setting := range settings {
+		settingsLang[setting.Key] = convert(setting.Key)
+		if setting.Help != "" {
+			settingsLang[fmt.Sprintf("%s-tips", setting.Key)] = setting.Help
+		}
+		if setting.Type == conf.TypeSelect && len(setting.Options) > 0 {
+			options := make(KV[string])
+			_options := strings.Split(setting.Options, ",")
+			for _, o := range _options {
+				options[o] = convert(o)
+			}
+			settingsLang[fmt.Sprintf("%ss", setting.Key)] = options
+		}
+	}
+	writeFile("settings", settingsLang)
+	//utils.WriteJsonToFile("lang/settings.json", settingsLang)
+}
+
+// LangCmd represents the lang command
+var LangCmd = &cobra.Command{
+	Use:   "lang",
+	Short: "Generate language json file",
+	Run: func(cmd *cobra.Command, args []string) {
+		err := os.MkdirAll("lang", 0777)
+		if err != nil {
+			utils.Log.Fatalf("failed to create folder: %s", err.Error())
+		}
+		generateDriversJson()
+		generateSettingsJson()
+	},
+}
+
+func init() {
+	RootCmd.AddCommand(LangCmd)
+
+	// Here you will define your flags and configuration settings.
+
+	// Cobra supports Persistent Flags which will work for this command
+	// and all subcommands, e.g.:
+	// langCmd.PersistentFlags().String("foo", "", "A help for foo")
+
+	// Cobra supports local flags which will only run when this command
+	// is called directly, e.g.:
+	// langCmd.Flags().BoolP("toggle", "t", false, "Help message for toggle")
+}
diff --git a/cmd/restart.go b/cmd/restart.go
new file mode 100644
index 0000000000000000000000000000000000000000..7795747fb6d4c6a2e704a6f494479921c2848ba1
--- /dev/null
+++ b/cmd/restart.go
@@ -0,0 +1,32 @@
+/*
+Copyright © 2022 NAME HERE
+*/
+package cmd
+
+import (
+	"github.com/spf13/cobra"
+)
+
+// RestartCmd represents the restart command
+var RestartCmd = &cobra.Command{
+	Use:   "restart",
+	Short: "Restart alist server by daemon/pid file",
+	Run: func(cmd *cobra.Command, args []string) {
+		stop()
+		start()
+	},
+}
+
+func init() {
+	RootCmd.AddCommand(RestartCmd)
+
+	// Here you will define your flags and configuration settings.
+
+	// Cobra supports Persistent Flags which will work for this command
+	// and all subcommands, e.g.:
+	// restartCmd.PersistentFlags().String("foo", "", "A help for foo")
+
+	// Cobra supports local flags which will only run when this command
+	// is called directly, e.g.:
+	// restartCmd.Flags().BoolP("toggle", "t", false, "Help message for toggle")
+}
diff --git a/cmd/root.go b/cmd/root.go
new file mode 100644
index 0000000000000000000000000000000000000000..6bd82b7a4a3103fb5d912ac2d9c2c4bc1bffb700
--- /dev/null
+++ b/cmd/root.go
@@ -0,0 +1,35 @@
+package cmd
+
+import (
+	"fmt"
+	"os"
+
+	"github.com/alist-org/alist/v3/cmd/flags"
+	_ "github.com/alist-org/alist/v3/drivers"
+	_ "github.com/alist-org/alist/v3/internal/offline_download"
+	"github.com/spf13/cobra"
+)
+
+var RootCmd = &cobra.Command{
+	Use:   "alist",
+	Short: "A file list program that supports multiple storages.",
+	Long: `A file list program that supports multiple storages,
+built with love by Xhofe and friends in Go/Solid.js.
+Complete documentation is available at https://alist.nn.ci/`,
+}
+
+func Execute() {
+	if err := RootCmd.Execute(); err != nil {
+		fmt.Fprintln(os.Stderr, err)
+		os.Exit(1)
+	}
+}
+
+func init() {
+	RootCmd.PersistentFlags().StringVar(&flags.DataDir, "data", "data", "data folder")
+	RootCmd.PersistentFlags().BoolVar(&flags.Debug, "debug", false, "start in debug mode")
+	RootCmd.PersistentFlags().BoolVar(&flags.NoPrefix, "no-prefix", false, "disable env prefix")
+	RootCmd.PersistentFlags().BoolVar(&flags.Dev, "dev", false, "start in dev mode")
+	RootCmd.PersistentFlags().BoolVar(&flags.ForceBinDir, "force-bin-dir", false, "force using the directory where the binary is located as the data directory")
+	RootCmd.PersistentFlags().BoolVar(&flags.LogStd, "log-std", false, "force logging to std")
+}
diff --git a/cmd/server.go b/cmd/server.go
new file mode 100644
index 0000000000000000000000000000000000000000..8a7beafa7fdbee1bcf94324f5368beda9e39bfc2
--- /dev/null
+++ b/cmd/server.go
@@ -0,0 +1,181 @@
+package cmd
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"net"
+	"net/http"
+	"os"
+	"os/signal"
+	"strconv"
+	"sync"
+	"syscall"
+	"time"
+
+	"github.com/alist-org/alist/v3/cmd/flags"
+	"github.com/alist-org/alist/v3/internal/bootstrap"
+	"github.com/alist-org/alist/v3/internal/conf"
+	"github.com/alist-org/alist/v3/pkg/utils"
+	"github.com/alist-org/alist/v3/server"
+	"github.com/gin-gonic/gin"
+	log "github.com/sirupsen/logrus"
+	"github.com/spf13/cobra"
+)
+
+// ServerCmd represents the server command
+var ServerCmd = &cobra.Command{
+	Use:   "server",
+	Short: "Start the server at the specified address",
+	Long: `Start the server at the specified address;
+the address is defined in the config file`,
+	Run: func(cmd *cobra.Command, args []string) {
+		Init()
+		if conf.Conf.DelayedStart != 0 {
+			utils.Log.Infof("delayed start for %d seconds", conf.Conf.DelayedStart)
+			time.Sleep(time.Duration(conf.Conf.DelayedStart) * time.Second)
+		}
+		bootstrap.InitOfflineDownloadTools()
+		bootstrap.LoadStorages()
+		bootstrap.InitTaskManager()
+		if !flags.Debug && !flags.Dev {
+			gin.SetMode(gin.ReleaseMode)
+		}
+		r := gin.New()
+		r.Use(gin.LoggerWithWriter(log.StandardLogger().Out), gin.RecoveryWithWriter(log.StandardLogger().Out))
+		server.Init(r)
+		var httpSrv, httpsSrv, unixSrv *http.Server
+		if conf.Conf.Scheme.HttpPort != -1 {
+			httpBase := fmt.Sprintf("%s:%d", conf.Conf.Scheme.Address, conf.Conf.Scheme.HttpPort)
+			utils.Log.Infof("start HTTP server @ %s", httpBase)
+			httpSrv = &http.Server{Addr: httpBase, Handler: r}
+			go func() {
+				err := httpSrv.ListenAndServe()
+				if err != nil && !errors.Is(err, http.ErrServerClosed) {
+					utils.Log.Fatalf("failed to start http: %s", err.Error())
+				}
+			}()
+		}
+		if conf.Conf.Scheme.HttpsPort != -1 {
+			httpsBase := fmt.Sprintf("%s:%d", conf.Conf.Scheme.Address, conf.Conf.Scheme.HttpsPort)
+			utils.Log.Infof("start HTTPS server @ %s", httpsBase)
+			httpsSrv = &http.Server{Addr: httpsBase, Handler: r}
+			go func() {
+				err := httpsSrv.ListenAndServeTLS(conf.Conf.Scheme.CertFile, conf.Conf.Scheme.KeyFile)
+				if err != nil && !errors.Is(err, http.ErrServerClosed) {
+					utils.Log.Fatalf("failed to start https: %s", err.Error())
+				}
+			}()
+		}
+		if conf.Conf.Scheme.UnixFile != "" {
+			utils.Log.Infof("start unix server @ %s", conf.Conf.Scheme.UnixFile)
+			unixSrv = &http.Server{Handler: r}
+			go func() {
+				listener, err := net.Listen("unix", conf.Conf.Scheme.UnixFile)
+				if err != nil {
+					utils.Log.Fatalf("failed to listen on unix socket: %+v", err)
+				}
+				// set socket file permission
+				mode, err := strconv.ParseUint(conf.Conf.Scheme.UnixFilePerm, 8, 32)
+				if err != nil {
+					utils.Log.Errorf("failed to parse socket file permission: %+v", err)
+				} else {
+					err = os.Chmod(conf.Conf.Scheme.UnixFile, os.FileMode(mode))
+					if err != nil {
+						utils.Log.Errorf("failed to chmod socket file: %+v", err)
+					}
+				}
+				err = unixSrv.Serve(listener)
+				if err != nil && !errors.Is(err, http.ErrServerClosed) {
+					utils.Log.Fatalf("failed to start unix: %s", err.Error())
+				}
+			}()
+		}
+		if conf.Conf.S3.Port != -1 && conf.Conf.S3.Enable {
+			s3r := gin.New()
+			s3r.Use(gin.LoggerWithWriter(log.StandardLogger().Out), gin.RecoveryWithWriter(log.StandardLogger().Out))
+			server.InitS3(s3r)
+			s3Base := fmt.Sprintf("%s:%d", conf.Conf.Scheme.Address, conf.Conf.S3.Port)
+			utils.Log.Infof("start S3 server @ %s", s3Base)
+			go func() {
+				var err error
+				if conf.Conf.S3.SSL {
+					httpsSrv = &http.Server{Addr: s3Base, Handler: s3r}
+					err = httpsSrv.ListenAndServeTLS(conf.Conf.Scheme.CertFile, conf.Conf.Scheme.KeyFile)
+				}
+				if !conf.Conf.S3.SSL {
+					httpSrv = &http.Server{Addr: s3Base, Handler: s3r}
+					err = httpSrv.ListenAndServe()
+				}
+				if err != nil && !errors.Is(err, http.ErrServerClosed) {
+					utils.Log.Fatalf("failed to start s3 server: %s", err.Error())
+				}
+			}()
+		}
+		// Wait for an interrupt signal to gracefully shut down the server with
+		// a timeout of 1 second.
+		quit := make(chan os.Signal, 1)
+		// kill (no param) sends syscall.SIGTERM by default
+		// kill -2 is syscall.SIGINT
+		// kill -9 is syscall.SIGKILL but can't be caught, so there is no need to add it
+		signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
+		<-quit
+		utils.Log.Println("Shutting down server...")
+		Release()
+		ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
+		defer cancel()
+		var wg sync.WaitGroup
+		if conf.Conf.Scheme.HttpPort != -1 {
+			wg.Add(1)
+			go func() {
+				defer wg.Done()
+				if err := httpSrv.Shutdown(ctx); err != nil {
+					utils.Log.Fatal("HTTP server shutdown err: ", err)
+				}
+			}()
+		}
+		if conf.Conf.Scheme.HttpsPort != -1 {
+			wg.Add(1)
+			go func() {
+				defer wg.Done()
+				if err := httpsSrv.Shutdown(ctx); err != nil {
+					utils.Log.Fatal("HTTPS server shutdown err: ", err)
+				}
+			}()
+		}
+		if conf.Conf.Scheme.UnixFile != "" {
+			wg.Add(1)
+			go func() {
+				defer wg.Done()
+				if err := unixSrv.Shutdown(ctx); err != nil {
+					utils.Log.Fatal("Unix server shutdown err: ", err)
+				}
+			}()
+		}
+		wg.Wait()
+		utils.Log.Println("Server exited")
+	},
+}
+
+func init() {
+	RootCmd.AddCommand(ServerCmd)
+
+	// Here you will define your flags and configuration settings.
+
+	// Cobra supports Persistent Flags which will work for this command
+	// and all subcommands, e.g.:
+	// serverCmd.PersistentFlags().String("foo", "", "A help for foo")
+
+	// Cobra supports local flags which will only run when this command
+	// is called directly, e.g.:
+	// serverCmd.Flags().BoolP("toggle", "t", false, "Help message for toggle")
+}
+
+// OutAlistInit exposes a function for starting the server from external code
+func OutAlistInit() {
+	var (
+		cmd  *cobra.Command
+		args []string
+	)
+	ServerCmd.Run(cmd, args)
+}
diff --git a/cmd/start.go b/cmd/start.go
new file mode 100644
index 0000000000000000000000000000000000000000..cf447112bd79fe33058a271a4a349d0f7b896947
--- /dev/null
+++ b/cmd/start.go
@@ -0,0 +1,71 @@
+/*
+Copyright © 2022 NAME HERE
+*/
+package cmd
+
+import (
+	"os"
+	"os/exec"
+	"path/filepath"
+	"strconv"
+
+	log "github.com/sirupsen/logrus"
+	"github.com/spf13/cobra"
+)
+
+// StartCmd represents the start command
+var StartCmd = &cobra.Command{
+	Use:   "start",
+	Short: "Silently start the alist server with `--force-bin-dir`",
+	Run: func(cmd *cobra.Command, args []string) {
+		start()
+	},
+}
+
+func start() {
+	initDaemon()
+	if pid != -1 {
+		_, err := os.FindProcess(pid)
+		if err == nil {
+			log.Info("alist already started, pid ", pid)
+			return
+		}
+	}
+	args := os.Args
+	args[1] = "server"
+	args = append(args, "--force-bin-dir")
+	cmd := &exec.Cmd{
+		Path: args[0],
+		Args: args,
+		Env:  os.Environ(),
+	}
+	stdout, err := os.OpenFile(filepath.Join(filepath.Dir(pidFile), "start.log"), os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0666)
+	if err != nil {
+		log.Fatal(os.Getpid(), ": failed to open start log file:", err)
+	}
+	cmd.Stderr = stdout
+	cmd.Stdout = stdout
+	err = cmd.Start()
+	if err != nil {
+		log.Fatal("failed to start child process: ", err)
+	}
+	log.Infof("successfully started, pid: %d", cmd.Process.Pid)
+	err = os.WriteFile(pidFile, []byte(strconv.Itoa(cmd.Process.Pid)), 0666)
+	if err != nil {
+		log.Warn("failed to record pid, you may not be able to stop the program with `./alist stop`")
+	}
+}
+
+func init() {
+	RootCmd.AddCommand(StartCmd)
+
+	// Here you will define your flags and configuration settings.
+
+	// Cobra supports Persistent Flags which will work for this command
+	// and all subcommands, e.g.:
+	// startCmd.PersistentFlags().String("foo", "", "A help for foo")
+
+	// Cobra supports local flags which will only run when this command
+	// is called directly, e.g.:
+	// startCmd.Flags().BoolP("toggle", "t", false, "Help message for toggle")
+}
diff --git a/cmd/stop.go b/cmd/stop.go
new file mode 100644
index 0000000000000000000000000000000000000000..09fba7b759d950f86f773657f1d8f5dc9f7c63ef
--- /dev/null
+++ b/cmd/stop.go
@@ -0,0 +1,58 @@
+/*
+Copyright © 2022 NAME HERE
+*/
+package cmd
+
+import (
+	"os"
+
+	log "github.com/sirupsen/logrus"
+	"github.com/spf13/cobra"
+)
+
+// StopCmd represents the stop command
+var StopCmd = &cobra.Command{
+	Use:   "stop",
+	Short: "Stop alist server by daemon/pid file",
+	Run: func(cmd *cobra.Command, args []string) {
+		stop()
+	},
+}
+
+func stop() {
+	initDaemon()
+	if pid == -1 {
+		log.Info("Seems like alist has not been started. Try `alist start` to start the server.")
+		return
+	}
+	process, err := os.FindProcess(pid)
+	if err != nil {
+		log.Errorf("failed to find process by pid: %d, reason: %v", pid, err)
+		return
+	}
+	err = process.Kill()
+	if err != nil {
+		log.Errorf("failed to kill process %d: %v", pid, err)
+	} else {
+		log.Info("killed process: ", pid)
+	}
+	err = os.Remove(pidFile)
+	if err != nil {
+		log.Errorf("failed to remove pid file: %v", err)
+	}
+	pid = -1
+}
+
+func init() {
+	RootCmd.AddCommand(StopCmd)
+
+	// Here you will define your flags and configuration settings.
+
+	// Cobra supports Persistent Flags which will work for this command
+	// and all subcommands, e.g.:
+	// stopCmd.PersistentFlags().String("foo", "", "A help for foo")
+
+	// Cobra supports local flags which will only run when this command
+	// is called directly, e.g.:
+	// stopCmd.Flags().BoolP("toggle", "t", false, "Help message for toggle")
+}
diff --git a/cmd/storage.go b/cmd/storage.go
new file mode 100644
index 0000000000000000000000000000000000000000..eabb5b40ba8f1a6968915a71e4fe15046258dd2f
--- /dev/null
+++ b/cmd/storage.go
@@ -0,0 +1,163 @@
+/*
+Copyright © 2023 NAME HERE
+*/
+package cmd
+
+import (
+	"os"
+	"strconv"
+
+	"github.com/alist-org/alist/v3/internal/db"
+	"github.com/alist-org/alist/v3/pkg/utils"
+	"github.com/charmbracelet/bubbles/table"
+	tea "github.com/charmbracelet/bubbletea"
+	"github.com/charmbracelet/lipgloss"
+	"github.com/spf13/cobra"
+)
+
+// storageCmd represents the storage command
+var storageCmd = &cobra.Command{
+	Use:   "storage",
+	Short: "Manage storage",
+}
+
+var disableStorageCmd = &cobra.Command{
+	Use:   "disable",
+	Short: "Disable a storage",
+	Run: func(cmd *cobra.Command, args []string) {
+		if len(args) < 1 {
+			utils.Log.Errorf("mount path is required")
+			return
+		}
+		mountPath := args[0]
+		Init()
+		defer Release()
+		storage, err := db.GetStorageByMountPath(mountPath)
+		if err != nil {
+			utils.Log.Errorf("failed to query storage: %+v", err)
+		} else {
+			storage.Disabled = true
+			err = db.UpdateStorage(storage)
+			if err != nil {
+				utils.Log.Errorf("failed to update storage: %+v", err)
+			} else {
+				utils.Log.Infof("Storage with mount path [%s] has been disabled", mountPath)
+			}
+		}
+	},
+}
+
+var baseStyle = lipgloss.NewStyle().
+	BorderStyle(lipgloss.NormalBorder()).
+ BorderForeground(lipgloss.Color("240")) + +type model struct { + table table.Model +} + +func (m model) Init() tea.Cmd { return nil } + +func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { + var cmd tea.Cmd + switch msg := msg.(type) { + case tea.KeyMsg: + switch msg.String() { + case "esc": + if m.table.Focused() { + m.table.Blur() + } else { + m.table.Focus() + } + case "q", "ctrl+c": + return m, tea.Quit + //case "enter": + // return m, tea.Batch( + // tea.Printf("Let's go to %s!", m.table.SelectedRow()[1]), + // ) + } + } + m.table, cmd = m.table.Update(msg) + return m, cmd +} + +func (m model) View() string { + return baseStyle.Render(m.table.View()) + "\n" +} + +var storageTableHeight int +var listStorageCmd = &cobra.Command{ + Use: "list", + Short: "List all storages", + Run: func(cmd *cobra.Command, args []string) { + Init() + defer Release() + storages, _, err := db.GetStorages(1, -1) + if err != nil { + utils.Log.Errorf("failed to query storages: %+v", err) + } else { + utils.Log.Infof("Found %d storages", len(storages)) + columns := []table.Column{ + {Title: "ID", Width: 4}, + {Title: "Driver", Width: 16}, + {Title: "Mount Path", Width: 30}, + {Title: "Enabled", Width: 7}, + } + + var rows []table.Row + for i := range storages { + storage := storages[i] + enabled := "true" + if storage.Disabled { + enabled = "false" + } + rows = append(rows, table.Row{ + strconv.Itoa(int(storage.ID)), + storage.Driver, + storage.MountPath, + enabled, + }) + } + t := table.New( + table.WithColumns(columns), + table.WithRows(rows), + table.WithFocused(true), + table.WithHeight(storageTableHeight), + ) + + s := table.DefaultStyles() + s.Header = s.Header. + BorderStyle(lipgloss.NormalBorder()). + BorderForeground(lipgloss.Color("240")). + BorderBottom(true). + Bold(false) + s.Selected = s.Selected. + Foreground(lipgloss.Color("229")). + Background(lipgloss.Color("57")). + Bold(false) + t.SetStyles(s) + + m := model{t} + if _, err := tea.NewProgram(m).Run(); err != nil { + utils.Log.Errorf("failed to run program: %+v", err) + os.Exit(1) + } + } + }, +} + +func init() { + + RootCmd.AddCommand(storageCmd) + storageCmd.AddCommand(disableStorageCmd) + storageCmd.AddCommand(listStorageCmd) + storageCmd.PersistentFlags().IntVarP(&storageTableHeight, "height", "H", 10, "Table height") + // Here you will define your flags and configuration settings. 
+ + // Cobra supports Persistent Flags which will work for this command + // and all subcommands, e.g.: + // storageCmd.PersistentFlags().String("foo", "", "A help for foo") + + // Cobra supports local flags which will only run when this command + // is called directly, e.g.: + // storageCmd.Flags().BoolP("toggle", "t", false, "Help message for toggle") +} diff --git a/cmd/user.go b/cmd/user.go new file mode 100644 index 0000000000000000000000000000000000000000..72cee5fa7ae3b8e7e7582a069a7891eadc12789c --- /dev/null +++ b/cmd/user.go @@ -0,0 +1,52 @@ +package cmd + +import ( + "crypto/tls" + "fmt" + "time" + + "github.com/alist-org/alist/v3/internal/conf" + "github.com/alist-org/alist/v3/internal/op" + "github.com/alist-org/alist/v3/internal/setting" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/go-resty/resty/v2" +) + +func DelAdminCacheOnline() { + admin, err := op.GetAdmin() + if err != nil { + utils.Log.Errorf("[del_admin_cache] get admin error: %+v", err) + return + } + DelUserCacheOnline(admin.Username) +} + +func DelUserCacheOnline(username string) { + client := resty.New().SetTimeout(1 * time.Second).SetTLSClientConfig(&tls.Config{InsecureSkipVerify: conf.Conf.TlsInsecureSkipVerify}) + token := setting.GetStr(conf.Token) + port := conf.Conf.Scheme.HttpPort + u := fmt.Sprintf("http://localhost:%d/api/admin/user/del_cache", port) + if port == -1 { + if conf.Conf.Scheme.HttpsPort == -1 { + utils.Log.Warnf("[del_user_cache] no open port") + return + } + u = fmt.Sprintf("https://localhost:%d/api/admin/user/del_cache", conf.Conf.Scheme.HttpsPort) + } + res, err := client.R().SetHeader("Authorization", token).SetQueryParam("username", username).Post(u) + if err != nil { + utils.Log.Warnf("[del_user_cache_online] failed: %+v", err) + return + } + if res.StatusCode() != 200 { + utils.Log.Warnf("[del_user_cache_online] failed: %+v", res.String()) + return + } + code := utils.Json.Get(res.Body(), "code").ToInt() + msg := utils.Json.Get(res.Body(), "message").ToString() + if code != 200 { + utils.Log.Errorf("[del_user_cache_online] error: %s", msg) + return + } + utils.Log.Debugf("[del_user_cache_online] del user [%s] cache success", username) +} diff --git a/cmd/version.go b/cmd/version.go new file mode 100644 index 0000000000000000000000000000000000000000..cdf4d71fceed2823f6411defd8cac96c9f96ca54 --- /dev/null +++ b/cmd/version.go @@ -0,0 +1,43 @@ +/* +Copyright © 2022 NAME HERE +*/ +package cmd + +import ( + "fmt" + "os" + + "github.com/alist-org/alist/v3/internal/conf" + "github.com/spf13/cobra" +) + +// VersionCmd represents the version command +var VersionCmd = &cobra.Command{ + Use: "version", + Short: "Show current version of AList", + Run: func(cmd *cobra.Command, args []string) { + fmt.Printf(`Built At: %s +Go Version: %s +Author: %s +Commit ID: %s +Version: %s +WebVersion: %s +`, + conf.BuiltAt, conf.GoVersion, conf.GitAuthor, conf.GitCommit, conf.Version, conf.WebVersion) + os.Exit(0) + }, +} + +func init() { + RootCmd.AddCommand(VersionCmd) + + // Here you will define your flags and configuration settings. 
+ + // Cobra supports Persistent Flags which will work for this command + // and all subcommands, e.g.: + // versionCmd.PersistentFlags().String("foo", "", "A help for foo") + + // Cobra supports local flags which will only run when this command + // is called directly, e.g.: + // versionCmd.Flags().BoolP("toggle", "t", false, "Help message for toggle") +} diff --git a/drivers/115/appver.go b/drivers/115/appver.go new file mode 100644 index 0000000000000000000000000000000000000000..78e11a5443f6aae5263aba0770e9dedb39105020 --- /dev/null +++ b/drivers/115/appver.go @@ -0,0 +1,43 @@ +package _115 + +import ( + driver115 "github.com/SheltonZhu/115driver/pkg/driver" + "github.com/alist-org/alist/v3/drivers/base" + log "github.com/sirupsen/logrus" +) + +var ( + md5Salt = "Qclm8MGWUv59TnrR0XPg" + appVer = "27.0.5.7" +) + +func (d *Pan115) getAppVersion() ([]driver115.AppVersion, error) { + result := driver115.VersionResp{} + resp, err := base.RestyClient.R().Get(driver115.ApiGetVersion) + + err = driver115.CheckErr(err, &result, resp) + if err != nil { + return nil, err + } + + return result.Data.GetAppVersions(), nil +} + +func (d *Pan115) getAppVer() string { + // todo add some cache? + vers, err := d.getAppVersion() + if err != nil { + log.Warnf("[115] get app version failed: %v", err) + return appVer + } + for _, ver := range vers { + if ver.AppName == "win" { + return ver.Version + } + } + return appVer +} + +func (d *Pan115) initAppVer() { + appVer = d.getAppVer() +} diff --git a/drivers/115/driver.go b/drivers/115/driver.go new file mode 100644 index 0000000000000000000000000000000000000000..4f584cd7b513d2a3788e2c6de550051631976975 --- /dev/null +++ b/drivers/115/driver.go @@ -0,0 +1,251 @@ +package _115 + +import ( + "context" + "strings" + "sync" + + driver115 "github.com/SheltonZhu/115driver/pkg/driver" + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/pkg/http_range" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/pkg/errors" + "golang.org/x/time/rate" +) + +type Pan115 struct { + model.Storage + Addition + client *driver115.Pan115Client + limiter *rate.Limiter + appVerOnce sync.Once +} + +func (d *Pan115) Config() driver.Config { + return config +} + +func (d *Pan115) GetAddition() driver.Additional { + return &d.Addition +} + +func (d *Pan115) Init(ctx context.Context) error { + d.appVerOnce.Do(d.initAppVer) + if d.LimitRate > 0 { + d.limiter = rate.NewLimiter(rate.Limit(d.LimitRate), 1) + } + return d.login() +} + +func (d *Pan115) WaitLimit(ctx context.Context) error { + if d.limiter != nil { + return d.limiter.Wait(ctx) + } + return nil +} + +func (d *Pan115) Drop(ctx context.Context) error { + return nil +} + +func (d *Pan115) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) { + if err := d.WaitLimit(ctx); err != nil { + return nil, err + } + files, err := d.getFiles(dir.GetID()) + if err != nil && !errors.Is(err, driver115.ErrNotExist) { + return nil, err + } + return utils.SliceConvert(files, func(src FileObj) (model.Obj, error) { + return &src, nil + }) +} + +func (d *Pan115) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) { + if err := d.WaitLimit(ctx); err != nil { + return nil, err + } + userAgent := args.Header.Get("User-Agent") + downloadInfo, err := d. 
+		DownloadWithUA(file.(*FileObj).PickCode, userAgent)
+	if err != nil {
+		return nil, err
+	}
+	link := &model.Link{
+		URL:    downloadInfo.Url.Url,
+		Header: downloadInfo.Header,
+	}
+	return link, nil
+}
+
+func (d *Pan115) MakeDir(ctx context.Context, parentDir model.Obj, dirName string) (model.Obj, error) {
+	if err := d.WaitLimit(ctx); err != nil {
+		return nil, err
+	}
+
+	result := driver115.MkdirResp{}
+	form := map[string]string{
+		"pid":   parentDir.GetID(),
+		"cname": dirName,
+	}
+	req := d.client.NewRequest().
+		SetFormData(form).
+		SetResult(&result).
+		ForceContentType("application/json;charset=UTF-8")
+
+	resp, err := req.Post(driver115.ApiDirAdd)
+
+	err = driver115.CheckErr(err, &result, resp)
+	if err != nil {
+		return nil, err
+	}
+	f, err := d.getNewFile(result.FileID)
+	if err != nil {
+		return nil, nil
+	}
+	return f, nil
+}
+
+func (d *Pan115) Move(ctx context.Context, srcObj, dstDir model.Obj) (model.Obj, error) {
+	if err := d.WaitLimit(ctx); err != nil {
+		return nil, err
+	}
+	if err := d.client.Move(dstDir.GetID(), srcObj.GetID()); err != nil {
+		return nil, err
+	}
+	f, err := d.getNewFile(srcObj.GetID())
+	if err != nil {
+		return nil, nil
+	}
+	return f, nil
+}
+
+func (d *Pan115) Rename(ctx context.Context, srcObj model.Obj, newName string) (model.Obj, error) {
+	if err := d.WaitLimit(ctx); err != nil {
+		return nil, err
+	}
+	if err := d.client.Rename(srcObj.GetID(), newName); err != nil {
+		return nil, err
+	}
+	f, err := d.getNewFile(srcObj.GetID())
+	if err != nil {
+		return nil, nil
+	}
+	return f, nil
+}
+
+func (d *Pan115) Copy(ctx context.Context, srcObj, dstDir model.Obj) error {
+	if err := d.WaitLimit(ctx); err != nil {
+		return err
+	}
+	return d.client.Copy(dstDir.GetID(), srcObj.GetID())
+}
+
+func (d *Pan115) Remove(ctx context.Context, obj model.Obj) error {
+	if err := d.WaitLimit(ctx); err != nil {
+		return err
+	}
+	return d.client.Delete(obj.GetID())
+}
+
+func (d *Pan115) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) (model.Obj, error) {
+	if err := d.WaitLimit(ctx); err != nil {
+		return nil, err
+	}
+
+	var (
+		fastInfo *driver115.UploadInitResp
+		dirID    = dstDir.GetID()
+	)
+
+	if ok, err := d.client.UploadAvailable(); err != nil || !ok {
+		return nil, err
+	}
+	if stream.GetSize() > d.client.UploadMetaInfo.SizeLimit {
+		return nil, driver115.ErrUploadTooLarge
+	}
+	//if digest, err = d.client.GetDigestResult(stream); err != nil {
+	//	return err
+	//}
+
+	const PreHashSize int64 = 128 * utils.KB
+	hashSize := PreHashSize
+	if stream.GetSize() < PreHashSize {
+		hashSize = stream.GetSize()
+	}
+	reader, err := stream.RangeRead(http_range.Range{Start: 0, Length: hashSize})
+	if err != nil {
+		return nil, err
+	}
+	preHash, err := utils.HashReader(utils.SHA1, reader)
+	if err != nil {
+		return nil, err
+	}
+	preHash = strings.ToUpper(preHash)
+	fullHash := stream.GetHash().GetHash(utils.SHA1)
+	if len(fullHash) <= 0 {
+		tmpF, err := stream.CacheFullInTempFile()
+		if err != nil {
+			return nil, err
+		}
+		fullHash, err = utils.HashFile(utils.SHA1, tmpF)
+		if err != nil {
+			return nil, err
+		}
+	}
+	fullHash = strings.ToUpper(fullHash)
+
+	// rapid-upload
+	// note that 115 adds a timeout to rapid-upload, and a "sig invalid"
+	// error is thrown after the timeout even when the hash is correct.
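+	// (A note on the hashes above: preHash covers at most the first 128 KB
+	// of the stream while fullHash covers the whole file; the cheap prefix
+	// hash presumably lets the server reject non-matches early, before the
+	// full SHA1 is ever compared.)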
+	if fastInfo, err = d.rapidUpload(stream.GetSize(), stream.GetName(), dirID, preHash, fullHash, stream); err != nil {
+		return nil, err
+	}
+	if matched, err := fastInfo.Ok(); err != nil {
+		return nil, err
+	} else if matched {
+		f, err := d.getNewFileByPickCode(fastInfo.PickCode)
+		if err != nil {
+			return nil, nil
+		}
+		return f, nil
+	}
+
+	var uploadResult *UploadResult
+	// rapid upload failed, fall back to a real upload
+	if stream.GetSize() <= 10*utils.MB { // files of 10 MB or smaller are uploaded in normal mode
+		if uploadResult, err = d.UploadByOSS(&fastInfo.UploadOSSParams, stream, dirID); err != nil {
+			return nil, err
+		}
+	} else {
+		// multipart upload
+		if uploadResult, err = d.UploadByMultipart(&fastInfo.UploadOSSParams, stream.GetSize(), stream, dirID); err != nil {
+			return nil, err
+		}
+	}
+
+	file, err := d.getNewFile(uploadResult.Data.FileID)
+	if err != nil {
+		return nil, nil
+	}
+	return file, nil
+}
+
+func (d *Pan115) OfflineList(ctx context.Context) ([]*driver115.OfflineTask, error) {
+	resp, err := d.client.ListOfflineTask(0)
+	if err != nil {
+		return nil, err
+	}
+	return resp.Tasks, nil
+}
+
+func (d *Pan115) OfflineDownload(ctx context.Context, uris []string, dstDir model.Obj) ([]string, error) {
+	return d.client.AddOfflineTaskURIs(uris, dstDir.GetID())
+}
+
+func (d *Pan115) DeleteOfflineTasks(ctx context.Context, hashes []string, deleteFiles bool) error {
+	return d.client.DeleteOfflineTasks(hashes, deleteFiles)
+}
+
+var _ driver.Driver = (*Pan115)(nil)
diff --git a/drivers/115/meta.go b/drivers/115/meta.go
new file mode 100644
index 0000000000000000000000000000000000000000..bcea174922c5e44abc54554bc74a07b43a3813ee
--- /dev/null
+++ b/drivers/115/meta.go
@@ -0,0 +1,29 @@
+package _115
+
+import (
+	"github.com/alist-org/alist/v3/internal/driver"
+	"github.com/alist-org/alist/v3/internal/op"
+)
+
+type Addition struct {
+	Cookie       string  `json:"cookie" type:"text" help:"one of QR code token and cookie required"`
+	QRCodeToken  string  `json:"qrcode_token" type:"text" help:"one of QR code token and cookie required"`
+	QRCodeSource string  `json:"qrcode_source" type:"select" options:"web,android,ios,tv,alipaymini,wechatmini,qandroid" default:"linux" help:"select the QR code device, default linux"`
+	PageSize     int64   `json:"page_size" type:"number" default:"1000" help:"list api per page size of 115 driver"`
+	LimitRate    float64 `json:"limit_rate" type:"float" default:"2" help:"limit all api request rate ([limit]r/1s)"`
+	driver.RootID
+}
+
+var config = driver.Config{
+	Name:        "115 Cloud",
+	DefaultRoot: "0",
+	// OnlyProxy:   true,
+	// OnlyLocal:   true,
+	// NoOverwriteUpload: true,
+}
+
+func init() {
+	op.RegisterDriver(func() driver.Driver {
+		return &Pan115{}
+	})
+}
diff --git a/drivers/115/types.go b/drivers/115/types.go
new file mode 100644
index 0000000000000000000000000000000000000000..40b951d80ce4dd9b65a1aa07f4fc347cb5d28516
--- /dev/null
+++ b/drivers/115/types.go
@@ -0,0 +1,38 @@
+package _115
+
+import (
+	"time"
+
+	"github.com/SheltonZhu/115driver/pkg/driver"
+	"github.com/alist-org/alist/v3/internal/model"
+	"github.com/alist-org/alist/v3/pkg/utils"
+)
+
+var _ model.Obj = (*FileObj)(nil)
+
+type FileObj struct {
+	driver.File
+}
+
+func (f *FileObj) CreateTime() time.Time {
+	return f.File.CreateTime
+}
+
+func (f *FileObj) GetHash() utils.HashInfo {
+	return utils.NewHashInfo(utils.SHA1, f.Sha1)
+}
+
+type UploadResult struct {
+	driver.BasicResp
+	Data struct {
+		PickCode string `json:"pick_code"`
+		FileSize int    `json:"file_size"`
+		FileID   string `json:"file_id"`
+		ThumbURL string `json:"thumb_url"`
+		Sha1     string `json:"sha1"`
Aid int `json:"aid"` + FileName string `json:"file_name"` + Cid string `json:"cid"` + IsVideo int `json:"is_video"` + } `json:"data"` +} diff --git a/drivers/115/util.go b/drivers/115/util.go new file mode 100644 index 0000000000000000000000000000000000000000..d7a1adff71cc9a3e6443ef8b0cf7ef0e88099ad1 --- /dev/null +++ b/drivers/115/util.go @@ -0,0 +1,537 @@ +package _115 + +import ( + "bytes" + "crypto/md5" + "crypto/tls" + "encoding/hex" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strconv" + "strings" + "sync" + "time" + + "github.com/alist-org/alist/v3/internal/conf" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/pkg/http_range" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/aliyun/aliyun-oss-go-sdk/oss" + + driver115 "github.com/SheltonZhu/115driver/pkg/driver" + crypto "github.com/gaoyb7/115drive-webdav/115" + "github.com/orzogc/fake115uploader/cipher" + "github.com/pkg/errors" +) + +// var UserAgent = driver115.UA115Browser +func (d *Pan115) login() error { + var err error + opts := []driver115.Option{ + driver115.UA(d.getUA()), + func(c *driver115.Pan115Client) { + c.Client.SetTLSClientConfig(&tls.Config{InsecureSkipVerify: conf.Conf.TlsInsecureSkipVerify}) + }, + } + d.client = driver115.New(opts...) + cr := &driver115.Credential{} + if d.QRCodeToken != "" { + s := &driver115.QRCodeSession{ + UID: d.QRCodeToken, + } + if cr, err = d.client.QRCodeLoginWithApp(s, driver115.LoginApp(d.QRCodeSource)); err != nil { + return errors.Wrap(err, "failed to login by qrcode") + } + d.Cookie = fmt.Sprintf("UID=%s;CID=%s;SEID=%s;KID=%s", cr.UID, cr.CID, cr.SEID, cr.KID) + d.QRCodeToken = "" + } else if d.Cookie != "" { + if err = cr.FromCookie(d.Cookie); err != nil { + return errors.Wrap(err, "failed to login by cookies") + } + d.client.ImportCredential(cr) + } else { + return errors.New("missing cookie or qrcode account") + } + return d.client.LoginCheck() +} + +func (d *Pan115) getFiles(fileId string) ([]FileObj, error) { + res := make([]FileObj, 0) + if d.PageSize <= 0 { + d.PageSize = driver115.FileListLimit + } + files, err := d.client.ListWithLimit(fileId, d.PageSize) + if err != nil { + return nil, err + } + for _, file := range *files { + res = append(res, FileObj{file}) + } + return res, nil +} + +func (d *Pan115) getNewFile(fileId string) (*FileObj, error) { + file, err := d.client.GetFile(fileId) + if err != nil { + return nil, err + } + return &FileObj{*file}, nil +} + +func (d *Pan115) getNewFileByPickCode(pickCode string) (*FileObj, error) { + result := driver115.GetFileInfoResponse{} + req := d.client.NewRequest(). + SetQueryParam("pick_code", pickCode). + ForceContentType("application/json;charset=UTF-8"). 
+		SetResult(&result)
+	resp, err := req.Get(driver115.ApiFileInfo)
+	if err := driver115.CheckErr(err, &result, resp); err != nil {
+		return nil, err
+	}
+	if len(result.Files) == 0 {
+		return nil, errors.New("failed to get file info")
+	}
+	fileInfo := result.Files[0]
+
+	f := &FileObj{}
+	f.From(fileInfo)
+	return f, nil
+}
+
+func (d *Pan115) getUA() string {
+	return fmt.Sprintf("Mozilla/5.0 115Browser/%s", appVer)
+}
+
+func (d *Pan115) DownloadWithUA(pickCode, ua string) (*driver115.DownloadInfo, error) {
+	key := crypto.GenerateKey()
+	result := driver115.DownloadResp{}
+	params, err := utils.Json.Marshal(map[string]string{"pickcode": pickCode})
+	if err != nil {
+		return nil, err
+	}
+
+	data := crypto.Encode(params, key)
+
+	bodyReader := strings.NewReader(url.Values{"data": []string{data}}.Encode())
+	reqUrl := fmt.Sprintf("%s?t=%s", driver115.ApiDownloadGetUrl, driver115.Now().String())
+	req, _ := http.NewRequest(http.MethodPost, reqUrl, bodyReader)
+	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+	req.Header.Set("Cookie", d.Cookie)
+	req.Header.Set("User-Agent", ua)
+
+	resp, err := d.client.Client.GetClient().Do(req)
+	if err != nil {
+		return nil, err
+	}
+	defer resp.Body.Close()
+
+	body, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return nil, err
+	}
+	if err := utils.Json.Unmarshal(body, &result); err != nil {
+		return nil, err
+	}
+
+	if err = result.Err(string(body)); err != nil {
+		return nil, err
+	}
+
+	bytes, err := crypto.Decode(string(result.EncodedData), key)
+	if err != nil {
+		return nil, err
+	}
+
+	downloadInfo := driver115.DownloadData{}
+	if err := utils.Json.Unmarshal(bytes, &downloadInfo); err != nil {
+		return nil, err
+	}
+
+	for _, info := range downloadInfo {
+		if info.FileSize < 0 {
+			return nil, driver115.ErrDownloadEmpty
+		}
+		info.Header = resp.Request.Header
+		return info, nil
+	}
+	return nil, driver115.ErrUnexpected
+}
+
+func (c *Pan115) GenerateToken(fileID, preID, timeStamp, fileSize, signKey, signVal string) string {
+	userID := strconv.FormatInt(c.client.UserID, 10)
+	userIDMd5 := md5.Sum([]byte(userID))
+	tokenMd5 := md5.Sum([]byte(md5Salt + fileID + fileSize + signKey + signVal + userID + timeStamp + hex.EncodeToString(userIDMd5[:]) + appVer))
+	return hex.EncodeToString(tokenMd5[:])
+}
+
+func (d *Pan115) rapidUpload(fileSize int64, fileName, dirID, preID, fileID string, stream model.FileStreamer) (*driver115.UploadInitResp, error) {
+	var (
+		ecdhCipher   *cipher.EcdhCipher
+		encrypted    []byte
+		decrypted    []byte
+		encodedToken string
+		err          error
+		target       = "U_1_" + dirID
+		bodyBytes    []byte
+		result       = driver115.UploadInitResp{}
+		fileSizeStr  = strconv.FormatInt(fileSize, 10)
+	)
+	if ecdhCipher, err = cipher.NewEcdhCipher(); err != nil {
+		return nil, err
+	}
+
+	userID := strconv.FormatInt(d.client.UserID, 10)
+	form := url.Values{}
+	form.Set("appid", "0")
+	form.Set("appversion", appVer)
+	form.Set("userid", userID)
+	form.Set("filename", fileName)
+	form.Set("filesize", fileSizeStr)
+	form.Set("fileid", fileID)
+	form.Set("target", target)
+	form.Set("sig", d.client.GenerateSignature(fileID, target))
+
+	signKey, signVal := "", ""
+	for retry := true; retry; {
+		t := driver115.NowMilli()
+
+		if encodedToken, err = ecdhCipher.EncodeToken(t.ToInt64()); err != nil {
+			return nil, err
+		}
+
+		params := map[string]string{
+			"k_ec": encodedToken,
+		}
+
+		form.Set("t", t.String())
+		form.Set("token", d.GenerateToken(fileID, preID, t.String(), fileSizeStr, signKey, signVal))
+		if signKey != "" && signVal != "" {
form.Set("sign_key", signKey) + form.Set("sign_val", signVal) + } + if encrypted, err = ecdhCipher.Encrypt([]byte(form.Encode())); err != nil { + return nil, err + } + + req := d.client.NewRequest(). + SetQueryParams(params). + SetBody(encrypted). + SetHeaderVerbatim("Content-Type", "application/x-www-form-urlencoded"). + SetDoNotParseResponse(true) + resp, err := req.Post(driver115.ApiUploadInit) + if err != nil { + return nil, err + } + data := resp.RawBody() + defer data.Close() + if bodyBytes, err = io.ReadAll(data); err != nil { + return nil, err + } + if decrypted, err = ecdhCipher.Decrypt(bodyBytes); err != nil { + return nil, err + } + if err = driver115.CheckErr(json.Unmarshal(decrypted, &result), &result, resp); err != nil { + return nil, err + } + if result.Status == 7 { + // Update signKey & signVal + signKey = result.SignKey + signVal, err = UploadDigestRange(stream, result.SignCheck) + if err != nil { + return nil, err + } + } else { + retry = false + } + result.SHA1 = fileID + } + + return &result, nil +} + +func UploadDigestRange(stream model.FileStreamer, rangeSpec string) (result string, err error) { + var start, end int64 + if _, err = fmt.Sscanf(rangeSpec, "%d-%d", &start, &end); err != nil { + return + } + + length := end - start + 1 + reader, err := stream.RangeRead(http_range.Range{Start: start, Length: length}) + if err != nil { + return "", err + } + hashStr, err := utils.HashReader(utils.SHA1, reader) + if err != nil { + return "", err + } + result = strings.ToUpper(hashStr) + return +} + +// UploadByOSS use aliyun sdk to upload +func (c *Pan115) UploadByOSS(params *driver115.UploadOSSParams, r io.Reader, dirID string) (*UploadResult, error) { + ossToken, err := c.client.GetOSSToken() + if err != nil { + return nil, err + } + ossClient, err := oss.New(driver115.OSSEndpoint, ossToken.AccessKeyID, ossToken.AccessKeySecret) + if err != nil { + return nil, err + } + bucket, err := ossClient.Bucket(params.Bucket) + if err != nil { + return nil, err + } + + var bodyBytes []byte + if err = bucket.PutObject(params.Object, r, append( + driver115.OssOption(params, ossToken), + oss.CallbackResult(&bodyBytes), + )...); err != nil { + return nil, err + } + + var uploadResult UploadResult + if err = json.Unmarshal(bodyBytes, &uploadResult); err != nil { + return nil, err + } + return &uploadResult, uploadResult.Err(string(bodyBytes)) +} + +// UploadByMultipart upload by mutipart blocks +func (d *Pan115) UploadByMultipart(params *driver115.UploadOSSParams, fileSize int64, stream model.FileStreamer, dirID string, opts ...driver115.UploadMultipartOption) (*UploadResult, error) { + var ( + chunks []oss.FileChunk + parts []oss.UploadPart + imur oss.InitiateMultipartUploadResult + ossClient *oss.Client + bucket *oss.Bucket + ossToken *driver115.UploadOSSTokenResp + bodyBytes []byte + err error + ) + + tmpF, err := stream.CacheFullInTempFile() + if err != nil { + return nil, err + } + + options := driver115.DefalutUploadMultipartOptions() + if len(opts) > 0 { + for _, f := range opts { + f(options) + } + } + // oss 启用Sequential必须按顺序上传 + options.ThreadsNum = 1 + + if ossToken, err = d.client.GetOSSToken(); err != nil { + return nil, err + } + + if ossClient, err = oss.New(driver115.OSSEndpoint, ossToken.AccessKeyID, ossToken.AccessKeySecret, oss.EnableMD5(true), oss.EnableCRC(true)); err != nil { + return nil, err + } + + if bucket, err = ossClient.Bucket(params.Bucket); err != nil { + return nil, err + } + + // ossToken一小时后就会失效,所以每50分钟重新获取一次 + ticker := 
+	defer ticker.Stop()
+	// set a timeout
+	timeout := time.NewTimer(options.Timeout)
+
+	if chunks, err = SplitFile(fileSize); err != nil {
+		return nil, err
+	}
+
+	if imur, err = bucket.InitiateMultipartUpload(params.Object,
+		oss.SetHeader(driver115.OssSecurityTokenHeaderName, ossToken.SecurityToken),
+		oss.UserAgentHeader(driver115.OSSUserAgent),
+		oss.EnableSha1(), oss.Sequential(),
+	); err != nil {
+		return nil, err
+	}
+
+	wg := sync.WaitGroup{}
+	wg.Add(len(chunks))
+
+	chunksCh := make(chan oss.FileChunk)
+	errCh := make(chan error)
+	UploadedPartsCh := make(chan oss.UploadPart)
+	quit := make(chan struct{})
+
+	// producer
+	go chunksProducer(chunksCh, chunks)
+	go func() {
+		wg.Wait()
+		quit <- struct{}{}
+	}()
+
+	// consumers
+	for i := 0; i < options.ThreadsNum; i++ {
+		go func(threadId int) {
+			defer func() {
+				if r := recover(); r != nil {
+					errCh <- fmt.Errorf("recovered in %v", r)
+				}
+			}()
+			for chunk := range chunksCh {
+				var part oss.UploadPart // retry on error, up to 3 attempts in total
+				for retry := 0; retry < 3; retry++ {
+					select {
+					case <-ticker.C:
+						if ossToken, err = d.client.GetOSSToken(); err != nil { // refresh the OSS token when the ticker fires
+							errCh <- errors.Wrap(err, "error refreshing token")
+						}
+					default:
+					}
+
+					buf := make([]byte, chunk.Size)
+					if _, err = tmpF.ReadAt(buf, chunk.Offset); err != nil && !errors.Is(err, io.EOF) {
+						continue
+					}
+
+					if part, err = bucket.UploadPart(imur, bytes.NewBuffer(buf), chunk.Size, chunk.Number, driver115.OssOption(params, ossToken)...); err == nil {
+						break
+					}
+				}
+				if err != nil {
+					errCh <- errors.Wrap(err, fmt.Sprintf("error uploading %s (part %d): %v", stream.GetName(), chunk.Number, err))
+				}
+				UploadedPartsCh <- part
+			}
+		}(i)
+	}
+
+	go func() {
+		for part := range UploadedPartsCh {
+			parts = append(parts, part)
+			wg.Done()
+		}
+	}()
+LOOP:
+	for {
+		select {
+		case <-ticker.C:
+			// refresh the OSS token when the ticker fires
+			if ossToken, err = d.client.GetOSSToken(); err != nil {
+				return nil, err
+			}
+		case <-quit:
+			break LOOP
+		case <-errCh:
+			return nil, err
+		case <-timeout.C:
+			return nil, fmt.Errorf("timed out")
+		}
+	}
+
+	// for unknown reasons, OSS does not compute the sha1 for multipart uploads,
+	// which makes the 115 server-side check fail
+	// params.Callback.Callback = strings.ReplaceAll(params.Callback.Callback, "${sha1}", params.SHA1)
+	if _, err := bucket.CompleteMultipartUpload(imur, parts, append(
+		driver115.OssOption(params, ossToken),
+		oss.CallbackResult(&bodyBytes),
+	)...); err != nil {
+		return nil, err
+	}
+
+	var uploadResult UploadResult
+	if err = json.Unmarshal(bodyBytes, &uploadResult); err != nil {
+		return nil, err
+	}
+	return &uploadResult, uploadResult.Err(string(bodyBytes))
+}
+
+func chunksProducer(ch chan oss.FileChunk, chunks []oss.FileChunk) {
+	for _, chunk := range chunks {
+		ch <- chunk
+	}
+}
+
+func SplitFile(fileSize int64) (chunks []oss.FileChunk, err error) {
+	for i := int64(1); i < 10; i++ {
+		if fileSize < i*utils.GB { // a file smaller than i GB is split into i*1000 parts
+			if chunks, err = SplitFileByPartNum(fileSize, int(i*1000)); err != nil {
+				return
+			}
+			break
+		}
+	}
+	if fileSize > 9*utils.GB { // a file larger than 9 GB is split into 10000 parts
+		if chunks, err = SplitFileByPartNum(fileSize, 10000); err != nil {
+			return
+		}
+	}
+	// a single part must not be smaller than 100 KB
+	if chunks[0].Size < 100*utils.KB {
+		if chunks, err = SplitFileByPartSize(fileSize, 100*utils.KB); err != nil {
+			return
+		}
+	}
+	return
+}
+
+// SplitFileByPartNum splits a big file into parts by the given number of parts.
+// It returns the split result when error is nil.
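+// For example, fileSize=1000 with chunkNum=3 yields parts at offsets 0, 333
+// and 666 with sizes 333, 333 and 334: the last part absorbs the remainder.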
+func SplitFileByPartNum(fileSize int64, chunkNum int) ([]oss.FileChunk, error) {
+	if chunkNum <= 0 || chunkNum > 10000 {
+		return nil, errors.New("chunkNum invalid")
+	}
+
+	if int64(chunkNum) > fileSize {
+		return nil, errors.New("chunkNum cannot be larger than fileSize")
+	}
+
+	var chunks []oss.FileChunk
+	chunk := oss.FileChunk{}
+	chunkN := (int64)(chunkNum)
+	for i := int64(0); i < chunkN; i++ {
+		chunk.Number = int(i + 1)
+		chunk.Offset = i * (fileSize / chunkN)
+		if i == chunkN-1 {
+			chunk.Size = fileSize/chunkN + fileSize%chunkN
+		} else {
+			chunk.Size = fileSize / chunkN
+		}
+		chunks = append(chunks, chunk)
+	}
+
+	return chunks, nil
+}
+
+// SplitFileByPartSize splits a big file into parts by the given part size.
+// It returns the FileChunk slice when error is nil.
+func SplitFileByPartSize(fileSize int64, chunkSize int64) ([]oss.FileChunk, error) {
+	if chunkSize <= 0 {
+		return nil, errors.New("chunkSize invalid")
+	}
+
+	chunkN := fileSize / chunkSize
+	if chunkN >= 10000 {
+		return nil, errors.New("too many parts, please increase part size")
+	}
+
+	var chunks []oss.FileChunk
+	chunk := oss.FileChunk{}
+	for i := int64(0); i < chunkN; i++ {
+		chunk.Number = int(i + 1)
+		chunk.Offset = i * chunkSize
+		chunk.Size = chunkSize
+		chunks = append(chunks, chunk)
+	}
+
+	if fileSize%chunkSize > 0 {
+		chunk.Number = len(chunks) + 1
+		chunk.Offset = int64(len(chunks)) * chunkSize
+		chunk.Size = fileSize % chunkSize
+		chunks = append(chunks, chunk)
+	}
+
+	return chunks, nil
+}
diff --git a/drivers/115_share/driver.go b/drivers/115_share/driver.go
new file mode 100644
index 0000000000000000000000000000000000000000..886a369c1b8bfca3fc8589624fe81f147e640602
--- /dev/null
+++ b/drivers/115_share/driver.go
@@ -0,0 +1,112 @@
+package _115_share
+
+import (
+	"context"
+
+	driver115 "github.com/SheltonZhu/115driver/pkg/driver"
+	"github.com/alist-org/alist/v3/internal/driver"
+	"github.com/alist-org/alist/v3/internal/errs"
+	"github.com/alist-org/alist/v3/internal/model"
+	"github.com/alist-org/alist/v3/pkg/utils"
+	"golang.org/x/time/rate"
+)
+
+type Pan115Share struct {
+	model.Storage
+	Addition
+	client  *driver115.Pan115Client
+	limiter *rate.Limiter
+}
+
+func (d *Pan115Share) Config() driver.Config {
+	return config
+}
+
+func (d *Pan115Share) GetAddition() driver.Additional {
+	return &d.Addition
+}
+
+func (d *Pan115Share) Init(ctx context.Context) error {
+	if d.LimitRate > 0 {
+		d.limiter = rate.NewLimiter(rate.Limit(d.LimitRate), 1)
+	}
+
+	return d.login()
+}
+
+func (d *Pan115Share) WaitLimit(ctx context.Context) error {
+	if d.limiter != nil {
+		return d.limiter.Wait(ctx)
+	}
+	return nil
+}
+
+func (d *Pan115Share) Drop(ctx context.Context) error {
+	return nil
+}
+
+func (d *Pan115Share) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) {
+	if err := d.WaitLimit(ctx); err != nil {
+		return nil, err
+	}
+
+	files := make([]driver115.ShareFile, 0)
+	fileResp, err := d.client.GetShareSnap(d.ShareCode, d.ReceiveCode, dir.GetID(), driver115.QueryLimit(int(d.PageSize)))
+	if err != nil {
+		return nil, err
+	}
+	files = append(files, fileResp.Data.List...)
+	total := fileResp.Data.Count
+	count := len(fileResp.Data.List)
+	for total > count {
+		fileResp, err := d.client.GetShareSnap(
+			d.ShareCode, d.ReceiveCode, dir.GetID(),
+			driver115.QueryLimit(int(d.PageSize)), driver115.QueryOffset(count),
+		)
+		if err != nil {
+			return nil, err
+		}
+		files = append(files, fileResp.Data.List...)
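+		// advance the paging offset by the number of entries fetched so far;
+		// the loop ends once the reported total has been retrieved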
+ count += len(fileResp.Data.List) + } + + return utils.SliceConvert(files, transFunc) +} + +func (d *Pan115Share) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) { + if err := d.WaitLimit(ctx); err != nil { + return nil, err + } + downloadInfo, err := d.client.DownloadByShareCode(d.ShareCode, d.ReceiveCode, file.GetID()) + if err != nil { + return nil, err + } + + return &model.Link{URL: downloadInfo.URL.URL}, nil +} + +func (d *Pan115Share) MakeDir(ctx context.Context, parentDir model.Obj, dirName string) error { + return errs.NotSupport +} + +func (d *Pan115Share) Move(ctx context.Context, srcObj, dstDir model.Obj) error { + return errs.NotSupport +} + +func (d *Pan115Share) Rename(ctx context.Context, srcObj model.Obj, newName string) error { + return errs.NotSupport +} + +func (d *Pan115Share) Copy(ctx context.Context, srcObj, dstDir model.Obj) error { + return errs.NotSupport +} + +func (d *Pan115Share) Remove(ctx context.Context, obj model.Obj) error { + return errs.NotSupport +} + +func (d *Pan115Share) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error { + return errs.NotSupport +} + +var _ driver.Driver = (*Pan115Share)(nil) diff --git a/drivers/115_share/meta.go b/drivers/115_share/meta.go new file mode 100644 index 0000000000000000000000000000000000000000..b3d2cc1fad7b7243f2e07b392935e24485d2bc3f --- /dev/null +++ b/drivers/115_share/meta.go @@ -0,0 +1,34 @@ +package _115_share + +import ( + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/op" +) + +type Addition struct { + Cookie string `json:"cookie" type:"text" help:"one of QR code token and cookie required"` + QRCodeToken string `json:"qrcode_token" type:"text" help:"one of QR code token and cookie required"` + QRCodeSource string `json:"qrcode_source" type:"select" options:"web,android,ios,tv,alipaymini,wechatmini,qandroid" default:"linux" help:"select the QR code device, default linux"` + PageSize int64 `json:"page_size" type:"number" default:"1000" help:"list api per page size of 115 driver"` + LimitRate float64 `json:"limit_rate" type:"float" default:"2" help:"limit all api request rate (1r/[limit_rate]s)"` + ShareCode string `json:"share_code" type:"text" required:"true" help:"share code of 115 share link"` + ReceiveCode string `json:"receive_code" type:"text" required:"true" help:"receive code of 115 share link"` + driver.RootID +} + +var config = driver.Config{ + Name: "115 Share", + DefaultRoot: "", + // OnlyProxy: true, + // OnlyLocal: true, + CheckStatus: false, + Alert: "", + NoOverwriteUpload: true, + NoUpload: true, +} + +func init() { + op.RegisterDriver(func() driver.Driver { + return &Pan115Share{} + }) +} diff --git a/drivers/115_share/utils.go b/drivers/115_share/utils.go new file mode 100644 index 0000000000000000000000000000000000000000..1f9e112deef04c4f788b78d84ded0046d594161c --- /dev/null +++ b/drivers/115_share/utils.go @@ -0,0 +1,111 @@ +package _115_share + +import ( + "fmt" + "strconv" + "time" + + driver115 "github.com/SheltonZhu/115driver/pkg/driver" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/pkg/errors" +) + +var _ model.Obj = (*FileObj)(nil) + +type FileObj struct { + Size int64 + Sha1 string + Utm time.Time + FileName string + isDir bool + FileID string +} + +func (f *FileObj) CreateTime() time.Time { + return f.Utm +} + +func (f *FileObj) GetHash() utils.HashInfo { + return utils.NewHashInfo(utils.SHA1, 
f.Sha1) +} + +func (f *FileObj) GetSize() int64 { + return f.Size +} + +func (f *FileObj) GetName() string { + return f.FileName +} + +func (f *FileObj) ModTime() time.Time { + return f.Utm +} + +func (f *FileObj) IsDir() bool { + return f.isDir +} + +func (f *FileObj) GetID() string { + return f.FileID +} + +func (f *FileObj) GetPath() string { + return "" +} + +func transFunc(sf driver115.ShareFile) (model.Obj, error) { + timeInt, err := strconv.ParseInt(sf.UpdateTime, 10, 64) + if err != nil { + return nil, err + } + var ( + utm = time.Unix(timeInt, 0) + isDir = (sf.IsFile == 0) + fileID = string(sf.FileID) + ) + if isDir { + fileID = string(sf.CategoryID) + } + return &FileObj{ + Size: int64(sf.Size), + Sha1: sf.Sha1, + Utm: utm, + FileName: string(sf.FileName), + isDir: isDir, + FileID: fileID, + }, nil +} + +var UserAgent = driver115.UA115Browser + +func (d *Pan115Share) login() error { + var err error + opts := []driver115.Option{ + driver115.UA(UserAgent), + } + d.client = driver115.New(opts...) + if _, err := d.client.GetShareSnap(d.ShareCode, d.ReceiveCode, ""); err != nil { + return errors.Wrap(err, "failed to get share snap") + } + cr := &driver115.Credential{} + if d.QRCodeToken != "" { + s := &driver115.QRCodeSession{ + UID: d.QRCodeToken, + } + if cr, err = d.client.QRCodeLoginWithApp(s, driver115.LoginApp(d.QRCodeSource)); err != nil { + return errors.Wrap(err, "failed to login by qrcode") + } + d.Cookie = fmt.Sprintf("UID=%s;CID=%s;SEID=%s;KID=%s", cr.UID, cr.CID, cr.SEID, cr.KID) + d.QRCodeToken = "" + } else if d.Cookie != "" { + if err = cr.FromCookie(d.Cookie); err != nil { + return errors.Wrap(err, "failed to login by cookies") + } + d.client.ImportCredential(cr) + } else { + return errors.New("missing cookie or qrcode account") + } + + return d.client.LoginCheck() +} diff --git a/drivers/123/driver.go b/drivers/123/driver.go new file mode 100644 index 0000000000000000000000000000000000000000..3620431d9b3662b4a5ca6eb7fab46724d81742b7 --- /dev/null +++ b/drivers/123/driver.go @@ -0,0 +1,267 @@ +package _123 + +import ( + "context" + "crypto/md5" + "encoding/base64" + "encoding/hex" + "fmt" + "golang.org/x/time/rate" + "io" + "net/http" + "net/url" + "sync" + "time" + + "github.com/alist-org/alist/v3/drivers/base" + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/errs" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/s3/s3manager" + "github.com/go-resty/resty/v2" + log "github.com/sirupsen/logrus" +) + +type Pan123 struct { + model.Storage + Addition + apiRateLimit sync.Map +} + +func (d *Pan123) Config() driver.Config { + return config +} + +func (d *Pan123) GetAddition() driver.Additional { + return &d.Addition +} + +func (d *Pan123) Init(ctx context.Context) error { + _, err := d.request(UserInfo, http.MethodGet, nil, nil) + return err +} + +func (d *Pan123) Drop(ctx context.Context) error { + _, _ = d.request(Logout, http.MethodPost, func(req *resty.Request) { + req.SetBody(base.Json{}) + }, nil) + return nil +} + +func (d *Pan123) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) { + files, err := d.getFiles(ctx, dir.GetID(), dir.GetName()) + if err != nil { + return nil, err + } + return utils.SliceConvert(files, func(src File) (model.Obj, error) { + return src, nil + }) +} + +func (d 
*Pan123) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) { + if f, ok := file.(File); ok { + //var resp DownResp + var headers map[string]string + if !utils.IsLocalIPAddr(args.IP) { + headers = map[string]string{ + //"X-Real-IP": "1.1.1.1", + "X-Forwarded-For": args.IP, + } + } + data := base.Json{ + "driveId": 0, + "etag": f.Etag, + "fileId": f.FileId, + "fileName": f.FileName, + "s3keyFlag": f.S3KeyFlag, + "size": f.Size, + "type": f.Type, + } + resp, err := d.request(DownloadInfo, http.MethodPost, func(req *resty.Request) { + + req.SetBody(data).SetHeaders(headers) + }, nil) + if err != nil { + return nil, err + } + downloadUrl := utils.Json.Get(resp, "data", "DownloadUrl").ToString() + u, err := url.Parse(downloadUrl) + if err != nil { + return nil, err + } + nu := u.Query().Get("params") + if nu != "" { + du, _ := base64.StdEncoding.DecodeString(nu) + u, err = url.Parse(string(du)) + if err != nil { + return nil, err + } + } + u_ := u.String() + log.Debug("download url: ", u_) + res, err := base.NoRedirectClient.R().SetHeader("Referer", "https://www.123pan.com/").Get(u_) + if err != nil { + return nil, err + } + log.Debug(res.String()) + link := model.Link{ + URL: u_, + } + log.Debugln("res code: ", res.StatusCode()) + if res.StatusCode() == 302 { + link.URL = res.Header().Get("location") + } else if res.StatusCode() < 300 { + link.URL = utils.Json.Get(res.Body(), "data", "redirect_url").ToString() + } + link.Header = http.Header{ + "Referer": []string{"https://www.123pan.com/"}, + } + return &link, nil + } else { + return nil, fmt.Errorf("can't convert obj") + } +} + +func (d *Pan123) MakeDir(ctx context.Context, parentDir model.Obj, dirName string) error { + data := base.Json{ + "driveId": 0, + "etag": "", + "fileName": dirName, + "parentFileId": parentDir.GetID(), + "size": 0, + "type": 1, + } + _, err := d.request(Mkdir, http.MethodPost, func(req *resty.Request) { + req.SetBody(data) + }, nil) + return err +} + +func (d *Pan123) Move(ctx context.Context, srcObj, dstDir model.Obj) error { + data := base.Json{ + "fileIdList": []base.Json{{"FileId": srcObj.GetID()}}, + "parentFileId": dstDir.GetID(), + } + _, err := d.request(Move, http.MethodPost, func(req *resty.Request) { + req.SetBody(data) + }, nil) + return err +} + +func (d *Pan123) Rename(ctx context.Context, srcObj model.Obj, newName string) error { + data := base.Json{ + "driveId": 0, + "fileId": srcObj.GetID(), + "fileName": newName, + } + _, err := d.request(Rename, http.MethodPost, func(req *resty.Request) { + req.SetBody(data) + }, nil) + return err +} + +func (d *Pan123) Copy(ctx context.Context, srcObj, dstDir model.Obj) error { + return errs.NotSupport +} + +func (d *Pan123) Remove(ctx context.Context, obj model.Obj) error { + if f, ok := obj.(File); ok { + data := base.Json{ + "driveId": 0, + "operation": true, + "fileTrashInfoList": []File{f}, + } + _, err := d.request(Trash, http.MethodPost, func(req *resty.Request) { + req.SetBody(data) + }, nil) + return err + } else { + return fmt.Errorf("can't convert obj") + } +} + +func (d *Pan123) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error { + // const DEFAULT int64 = 10485760 + h := md5.New() + // need to calculate md5 of the full content + tempFile, err := stream.CacheFullInTempFile() + if err != nil { + return err + } + defer func() { + _ = tempFile.Close() + }() + if _, err = utils.CopyWithBuffer(h, tempFile); err != nil { + return err + } + _, err = tempFile.Seek(0, 
io.SeekStart) + if err != nil { + return err + } + etag := hex.EncodeToString(h.Sum(nil)) + data := base.Json{ + "driveId": 0, + "duplicate": 2, // 2 -> overwrite, 1 -> rename, 0 -> default + "etag": etag, + "fileName": stream.GetName(), + "parentFileId": dstDir.GetID(), + "size": stream.GetSize(), + "type": 0, + } + var resp UploadResp + res, err := d.request(UploadRequest, http.MethodPost, func(req *resty.Request) { + req.SetBody(data).SetContext(ctx) + }, &resp) + if err != nil { + return err + } + log.Debugln("upload request res: ", string(res)) + if resp.Data.Reuse || resp.Data.Key == "" { + return nil + } + if resp.Data.AccessKeyId == "" || resp.Data.SecretAccessKey == "" || resp.Data.SessionToken == "" { + err = d.newUpload(ctx, &resp, stream, tempFile, up) + return err + } else { + cfg := &aws.Config{ + Credentials: credentials.NewStaticCredentials(resp.Data.AccessKeyId, resp.Data.SecretAccessKey, resp.Data.SessionToken), + Region: aws.String("123pan"), + Endpoint: aws.String(resp.Data.EndPoint), + S3ForcePathStyle: aws.Bool(true), + } + s, err := session.NewSession(cfg) + if err != nil { + return err + } + uploader := s3manager.NewUploader(s) + if stream.GetSize() > s3manager.MaxUploadParts*s3manager.DefaultUploadPartSize { + uploader.PartSize = stream.GetSize() / (s3manager.MaxUploadParts - 1) + } + input := &s3manager.UploadInput{ + Bucket: &resp.Data.Bucket, + Key: &resp.Data.Key, + Body: tempFile, + } + _, err = uploader.UploadWithContext(ctx, input) + // check the S3 upload result here; otherwise the completion request below would silently overwrite err + if err != nil { + return err + } + } + _, err = d.request(UploadComplete, http.MethodPost, func(req *resty.Request) { + req.SetBody(base.Json{ + "fileId": resp.Data.FileId, + }).SetContext(ctx) + }, nil) + return err +} + +func (d *Pan123) APIRateLimit(ctx context.Context, api string) error { + value, _ := d.apiRateLimit.LoadOrStore(api, + rate.NewLimiter(rate.Every(700*time.Millisecond), 1)) + limiter := value.(*rate.Limiter) + + return limiter.Wait(ctx) +} + +var _ driver.Driver = (*Pan123)(nil) diff --git a/drivers/123/meta.go b/drivers/123/meta.go new file mode 100644 index 0000000000000000000000000000000000000000..cb2cbc15ba07c17576efaf39309478d2d4667281 --- /dev/null +++ b/drivers/123/meta.go @@ -0,0 +1,27 @@ +package _123 + +import ( + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/op" +) + +type Addition struct { + Username string `json:"username" required:"true"` + Password string `json:"password" required:"true"` + driver.RootID + //OrderBy string `json:"order_by" type:"select" options:"file_id,file_name,size,update_at" default:"file_name"` + //OrderDirection string `json:"order_direction" type:"select" options:"asc,desc" default:"asc"` + AccessToken string +} + +var config = driver.Config{ + Name: "123Pan", + DefaultRoot: "0", + LocalSort: true, +} + +func init() { + op.RegisterDriver(func() driver.Driver { + return &Pan123{} + }) +} diff --git a/drivers/123/types.go b/drivers/123/types.go new file mode 100644 index 0000000000000000000000000000000000000000..a8682c52fc9564f4fcef75392fae31030874c3e1 --- /dev/null +++ b/drivers/123/types.go @@ -0,0 +1,123 @@ +package _123 + +import ( + "github.com/alist-org/alist/v3/pkg/utils" + "net/url" + "path" + "strconv" + "strings" + "time" + + "github.com/alist-org/alist/v3/internal/model" +) + +type File struct { + FileName string `json:"FileName"` + Size int64 `json:"Size"` + UpdateAt time.Time `json:"UpdateAt"` + FileId int64 `json:"FileId"` + Type int `json:"Type"` + Etag string `json:"Etag"` + S3KeyFlag string `json:"S3KeyFlag"` + DownloadUrl string `json:"DownloadUrl"` +} + +func (f File)
CreateTime() time.Time { + return f.UpdateAt +} + +func (f File) GetHash() utils.HashInfo { + return utils.HashInfo{} +} + +func (f File) GetPath() string { + return "" +} + +func (f File) GetSize() int64 { + return f.Size +} + +func (f File) GetName() string { + return f.FileName +} + +func (f File) ModTime() time.Time { + return f.UpdateAt +} + +func (f File) IsDir() bool { + return f.Type == 1 +} + +func (f File) GetID() string { + return strconv.FormatInt(f.FileId, 10) +} + +func (f File) Thumb() string { + if f.DownloadUrl == "" { + return "" + } + du, err := url.Parse(f.DownloadUrl) + if err != nil { + return "" + } + du.Path = strings.TrimSuffix(du.Path, "_24_24") + "_70_70" + query := du.Query() + query.Set("w", "70") + query.Set("h", "70") + if !query.Has("type") { + query.Set("type", strings.TrimPrefix(path.Base(f.FileName), ".")) + } + if !query.Has("trade_key") { + query.Set("trade_key", "123pan-thumbnail") + } + du.RawQuery = query.Encode() + return du.String() +} + +var _ model.Obj = (*File)(nil) +var _ model.Thumb = (*File)(nil) + +//func (f File) Thumb() string { +// +//} +//var _ model.Thumb = (*File)(nil) + +type Files struct { + //BaseResp + Data struct { + Next string `json:"Next"` + Total int `json:"Total"` + InfoList []File `json:"InfoList"` + } `json:"data"` +} + +//type DownResp struct { +// //BaseResp +// Data struct { +// DownloadUrl string `json:"DownloadUrl"` +// } `json:"data"` +//} + +type UploadResp struct { + //BaseResp + Data struct { + AccessKeyId string `json:"AccessKeyId"` + Bucket string `json:"Bucket"` + Key string `json:"Key"` + SecretAccessKey string `json:"SecretAccessKey"` + SessionToken string `json:"SessionToken"` + FileId int64 `json:"FileId"` + Reuse bool `json:"Reuse"` + EndPoint string `json:"EndPoint"` + StorageNode string `json:"StorageNode"` + UploadId string `json:"UploadId"` + } `json:"data"` +} + +type S3PreSignedURLs struct { + Data struct { + PreSignedUrls map[string]string `json:"presignedUrls"` + } `json:"data"` +} diff --git a/drivers/123/upload.go b/drivers/123/upload.go new file mode 100644 index 0000000000000000000000000000000000000000..6f6221f11487bd14a3ef9d5b613705348e85bc6f --- /dev/null +++ b/drivers/123/upload.go @@ -0,0 +1,155 @@ +package _123 + +import ( + "context" + "fmt" + "io" + "math" + "net/http" + "strconv" + + "github.com/alist-org/alist/v3/drivers/base" + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/go-resty/resty/v2" +) + +func (d *Pan123) getS3PreSignedUrls(ctx context.Context, upReq *UploadResp, start, end int) (*S3PreSignedURLs, error) { + data := base.Json{ + "bucket": upReq.Data.Bucket, + "key": upReq.Data.Key, + "partNumberEnd": end, + "partNumberStart": start, + "uploadId": upReq.Data.UploadId, + "StorageNode": upReq.Data.StorageNode, + } + var s3PreSignedUrls S3PreSignedURLs + _, err := d.request(S3PreSignedUrls, http.MethodPost, func(req *resty.Request) { + req.SetBody(data).SetContext(ctx) + }, &s3PreSignedUrls) + if err != nil { + return nil, err + } + return &s3PreSignedUrls, nil +} + +func (d *Pan123) getS3Auth(ctx context.Context, upReq *UploadResp, start, end int) (*S3PreSignedURLs, error) { + data := base.Json{ + "StorageNode": upReq.Data.StorageNode, + "bucket": upReq.Data.Bucket, + "key": upReq.Data.Key, + "partNumberEnd": end, + "partNumberStart": start, + "uploadId": upReq.Data.UploadId, + } + var s3PreSignedUrls S3PreSignedURLs + _, err := d.request(S3Auth, http.MethodPost, func(req 
*resty.Request) { + req.SetBody(data).SetContext(ctx) + }, &s3PreSignedUrls) + if err != nil { + return nil, err + } + return &s3PreSignedUrls, nil +} + +func (d *Pan123) completeS3(ctx context.Context, upReq *UploadResp, file model.FileStreamer, isMultipart bool) error { + data := base.Json{ + "StorageNode": upReq.Data.StorageNode, + "bucket": upReq.Data.Bucket, + "fileId": upReq.Data.FileId, + "fileSize": file.GetSize(), + "isMultipart": isMultipart, + "key": upReq.Data.Key, + "uploadId": upReq.Data.UploadId, + } + _, err := d.request(UploadCompleteV2, http.MethodPost, func(req *resty.Request) { + req.SetBody(data).SetContext(ctx) + }, nil) + return err +} + +func (d *Pan123) newUpload(ctx context.Context, upReq *UploadResp, file model.FileStreamer, reader io.Reader, up driver.UpdateProgress) error { + chunkSize := int64(1024 * 1024 * 16) + // fetch s3 pre signed urls + chunkCount := int(math.Ceil(float64(file.GetSize()) / float64(chunkSize))) + // only 1 batch is allowed + isMultipart := chunkCount > 1 + batchSize := 1 + getS3UploadUrl := d.getS3Auth + if isMultipart { + batchSize = 10 + getS3UploadUrl = d.getS3PreSignedUrls + } + for i := 1; i <= chunkCount; i += batchSize { + if utils.IsCanceled(ctx) { + return ctx.Err() + } + start := i + end := i + batchSize + if end > chunkCount+1 { + end = chunkCount + 1 + } + s3PreSignedUrls, err := getS3UploadUrl(ctx, upReq, start, end) + if err != nil { + return err + } + // upload each chunk + for j := start; j < end; j++ { + if utils.IsCanceled(ctx) { + return ctx.Err() + } + curSize := chunkSize + if j == chunkCount { + curSize = file.GetSize() - (int64(chunkCount)-1)*chunkSize + } + err = d.uploadS3Chunk(ctx, upReq, s3PreSignedUrls, j, end, io.LimitReader(reader, chunkSize), curSize, false, getS3UploadUrl) + if err != nil { + return err + } + up(float64(j) * 100 / float64(chunkCount)) + } + } + // complete s3 upload + return d.completeS3(ctx, upReq, file, chunkCount > 1) +} + +func (d *Pan123) uploadS3Chunk(ctx context.Context, upReq *UploadResp, s3PreSignedUrls *S3PreSignedURLs, cur, end int, reader io.Reader, curSize int64, retry bool, getS3UploadUrl func(ctx context.Context, upReq *UploadResp, start int, end int) (*S3PreSignedURLs, error)) error { + uploadUrl := s3PreSignedUrls.Data.PreSignedUrls[strconv.Itoa(cur)] + if uploadUrl == "" { + return fmt.Errorf("upload url is empty, s3PreSignedUrls: %+v", s3PreSignedUrls) + } + req, err := http.NewRequest("PUT", uploadUrl, reader) + if err != nil { + return err + } + req = req.WithContext(ctx) + req.ContentLength = curSize + //req.Header.Set("Content-Length", strconv.FormatInt(curSize, 10)) + res, err := base.HttpClient.Do(req) + if err != nil { + return err + } + defer res.Body.Close() + if res.StatusCode == http.StatusForbidden { + if retry { + return fmt.Errorf("upload s3 chunk %d failed, status code: %d", cur, res.StatusCode) + } + // refresh s3 pre signed urls + newS3PreSignedUrls, err := getS3UploadUrl(ctx, upReq, cur, end) + if err != nil { + return err + } + s3PreSignedUrls.Data.PreSignedUrls = newS3PreSignedUrls.Data.PreSignedUrls + // retry + return d.uploadS3Chunk(ctx, upReq, s3PreSignedUrls, cur, end, reader, curSize, true, getS3UploadUrl) + } + if res.StatusCode != http.StatusOK { + body, err := io.ReadAll(res.Body) + if err != nil { + return err + } + return fmt.Errorf("upload s3 chunk %d failed, status code: %d, body: %s", cur, res.StatusCode, body) + } + return nil +} diff --git a/drivers/123/util.go b/drivers/123/util.go new file mode 100644 index 
0000000000000000000000000000000000000000..6365b1c9a1e65bb886709a28ab65606e0ceb88e5 --- /dev/null +++ b/drivers/123/util.go @@ -0,0 +1,281 @@ +package _123 + +import ( + "context" + "errors" + "fmt" + "hash/crc32" + "math" + "math/rand" + "net/http" + "net/url" + "strconv" + "strings" + "time" + + "github.com/alist-org/alist/v3/drivers/base" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/go-resty/resty/v2" + jsoniter "github.com/json-iterator/go" + log "github.com/sirupsen/logrus" +) + +// do others that not defined in Driver interface + +const ( + Api = "https://www.123pan.com/api" + AApi = "https://www.123pan.com/a/api" + BApi = "https://www.123pan.com/b/api" + LoginApi = "https://login.123pan.com/api" + MainApi = BApi + SignIn = LoginApi + "/user/sign_in" + Logout = MainApi + "/user/logout" + UserInfo = MainApi + "/user/info" + FileList = MainApi + "/file/list/new" + DownloadInfo = MainApi + "/file/download_info" + Mkdir = MainApi + "/file/upload_request" + Move = MainApi + "/file/mod_pid" + Rename = MainApi + "/file/rename" + Trash = MainApi + "/file/trash" + UploadRequest = MainApi + "/file/upload_request" + UploadComplete = MainApi + "/file/upload_complete" + S3PreSignedUrls = MainApi + "/file/s3_repare_upload_parts_batch" + S3Auth = MainApi + "/file/s3_upload_object/auth" + UploadCompleteV2 = MainApi + "/file/upload_complete/v2" + S3Complete = MainApi + "/file/s3_complete_multipart_upload" + //AuthKeySalt = "8-8D$sL8gPjom7bk#cY" +) + +func signPath(path string, os string, version string) (k string, v string) { + table := []byte{'a', 'd', 'e', 'f', 'g', 'h', 'l', 'm', 'y', 'i', 'j', 'n', 'o', 'p', 'k', 'q', 'r', 's', 't', 'u', 'b', 'c', 'v', 'w', 's', 'z'} + random := fmt.Sprintf("%.f", math.Round(1e7*rand.Float64())) + now := time.Now().In(time.FixedZone("CST", 8*3600)) + timestamp := fmt.Sprint(now.Unix()) + nowStr := []byte(now.Format("200601021504")) + for i := 0; i < len(nowStr); i++ { + nowStr[i] = table[nowStr[i]-48] + } + timeSign := fmt.Sprint(crc32.ChecksumIEEE(nowStr)) + data := strings.Join([]string{timestamp, random, path, os, version, timeSign}, "|") + dataSign := fmt.Sprint(crc32.ChecksumIEEE([]byte(data))) + return timeSign, strings.Join([]string{timestamp, random, dataSign}, "-") +} + +func GetApi(rawUrl string) string { + u, _ := url.Parse(rawUrl) + query := u.Query() + query.Add(signPath(u.Path, "web", "3")) + u.RawQuery = query.Encode() + return u.String() +} + +//func GetApi(url string) string { +// vm := js.New() +// vm.Set("url", url[22:]) +// r, err := vm.RunString(` +// (function(e){ +// function A(t, e) { +// e = 1 < arguments.length && void 0 !== e ? e : 10; +// for (var n = function() { +// for (var t = [], e = 0; e < 256; e++) { +// for (var n = e, r = 0; r < 8; r++) +// n = 1 & n ? 3988292384 ^ n >>> 1 : n >>> 1; +// t[e] = n +// } +// return t +// }(), r = function(t) { +// t = t.replace(/\\r\\n/g, "\\n"); +// for (var e = "", n = 0; n < t.length; n++) { +// var r = t.charCodeAt(n); +// r < 128 ? e += String.fromCharCode(r) : e = 127 < r && r < 2048 ? (e += String.fromCharCode(r >> 6 | 192)) + String.fromCharCode(63 & r | 128) : (e = (e += String.fromCharCode(r >> 12 | 224)) + String.fromCharCode(r >> 6 & 63 | 128)) + String.fromCharCode(63 & r | 128) +// } +// return e +// }(t), a = -1, i = 0; i < r.length; i++) +// a = a >>> 8 ^ n[255 & (a ^ r.charCodeAt(i))]; +// return (a = (-1 ^ a) >>> 0).toString(e) +// } +// +// function v(t) { +// return (v = "function" == typeof Symbol && "symbol" == typeof Symbol.iterator ? 
function(t) { +// return typeof t +// } +// : function(t) { +// return t && "function" == typeof Symbol && t.constructor === Symbol && t !== Symbol.prototype ? "symbol" : typeof t +// } +// )(t) +// } +// +// for (p in a = Math.round(1e7 * Math.random()), +// o = Math.round(((new Date).getTime() + 60 * (new Date).getTimezoneOffset() * 1e3 + 288e5) / 1e3).toString(), +// m = ["a", "d", "e", "f", "g", "h", "l", "m", "y", "i", "j", "n", "o", "p", "k", "q", "r", "s", "t", "u", "b", "c", "v", "w", "s", "z"], +// u = function(t, e, n) { +// var r; +// n = 2 < arguments.length && void 0 !== n ? n : 8; +// return 0 === arguments.length ? null : (r = "object" === v(t) ? t : (10 === "".concat(t).length && (t = 1e3 * Number.parseInt(t)), +// new Date(t)), +// t += 6e4 * new Date(t).getTimezoneOffset(), +// { +// y: (r = new Date(t + 36e5 * n)).getFullYear(), +// m: r.getMonth() + 1 < 10 ? "0".concat(r.getMonth() + 1) : r.getMonth() + 1, +// d: r.getDate() < 10 ? "0".concat(r.getDate()) : r.getDate(), +// h: r.getHours() < 10 ? "0".concat(r.getHours()) : r.getHours(), +// f: r.getMinutes() < 10 ? "0".concat(r.getMinutes()) : r.getMinutes() +// }) +// }(o), +// h = u.y, +// g = u.m, +// l = u.d, +// c = u.h, +// u = u.f, +// d = [h, g, l, c, u].join(""), +// f = [], +// d) +// f.push(m[Number(d[p])]); +// return h = A(f.join("")), +// g = A("".concat(o, "|").concat(a, "|").concat(e, "|").concat("web", "|").concat("3", "|").concat(h)), +// "".concat(h, "=").concat(o, "-").concat(a, "-").concat(g); +// })(url) +// `) +// if err != nil { +// fmt.Println(err) +// return url +// } +// v, _ := r.Export().(string) +// return url + "?" + v +//} + +func (d *Pan123) login() error { + var body base.Json + if utils.IsEmailFormat(d.Username) { + body = base.Json{ + "mail": d.Username, + "password": d.Password, + "type": 2, + } + } else { + body = base.Json{ + "passport": d.Username, + "password": d.Password, + "remember": true, + } + } + res, err := base.RestyClient.R(). + SetHeaders(map[string]string{ + "origin": "https://www.123pan.com", + "referer": "https://www.123pan.com/", + "user-agent": "Dart/2.19(dart:io)-alist", + "platform": "web", + "app-version": "3", + //"user-agent": base.UserAgent, + }). 
+ SetBody(body).Post(SignIn) + if err != nil { + return err + } + if utils.Json.Get(res.Body(), "code").ToInt() != 200 { + err = fmt.Errorf(utils.Json.Get(res.Body(), "message").ToString()) + } else { + d.AccessToken = utils.Json.Get(res.Body(), "data", "token").ToString() + } + return err +} + +//func authKey(reqUrl string) (*string, error) { +// reqURL, err := url.Parse(reqUrl) +// if err != nil { +// return nil, err +// } +// +// nowUnix := time.Now().Unix() +// random := rand.Intn(0x989680) +// +// p4 := fmt.Sprintf("%d|%d|%s|%s|%s|%s", nowUnix, random, reqURL.Path, "web", "3", AuthKeySalt) +// authKey := fmt.Sprintf("%d-%d-%x", nowUnix, random, md5.Sum([]byte(p4))) +// return &authKey, nil +//} + +func (d *Pan123) request(url string, method string, callback base.ReqCallback, resp interface{}) ([]byte, error) { + req := base.RestyClient.R() + req.SetHeaders(map[string]string{ + "origin": "https://www.123pan.com", + "referer": "https://www.123pan.com/", + "authorization": "Bearer " + d.AccessToken, + "user-agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) alist-client", + "platform": "web", + "app-version": "3", + //"user-agent": base.UserAgent, + }) + if callback != nil { + callback(req) + } + if resp != nil { + req.SetResult(resp) + } + //authKey, err := authKey(url) + //if err != nil { + // return nil, err + //} + //req.SetQueryParam("auth-key", *authKey) + res, err := req.Execute(method, GetApi(url)) + if err != nil { + return nil, err + } + body := res.Body() + code := utils.Json.Get(body, "code").ToInt() + if code != 0 { + if code == 401 { + err := d.login() + if err != nil { + return nil, err + } + return d.request(url, method, callback, resp) + } + return nil, errors.New(jsoniter.Get(body, "message").ToString()) + } + return body, nil +} + +func (d *Pan123) getFiles(ctx context.Context, parentId string, name string) ([]File, error) { + page := 1 + total := 0 + res := make([]File, 0) + // 2024-02-06 fix concurrency by 123pan + for { + if err := d.APIRateLimit(ctx, FileList); err != nil { + return nil, err + } + var resp Files + query := map[string]string{ + "driveId": "0", + "limit": "100", + "next": "0", + "orderBy": "file_id", + "orderDirection": "desc", + "parentFileId": parentId, + "trashed": "false", + "SearchData": "", + "Page": strconv.Itoa(page), + "OnlyLookAbnormalFile": "0", + "event": "homeListFile", + "operateType": "4", + "inDirectSpace": "false", + } + _res, err := d.request(FileList, http.MethodGet, func(req *resty.Request) { + req.SetQueryParams(query) + }, &resp) + if err != nil { + return nil, err + } + log.Debug(string(_res)) + page++ + res = append(res, resp.Data.InfoList...) 
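+ // remember the server-reported total; after the loop it is compared with len(res) to flag incomplete listings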
+ total = resp.Data.Total + if len(resp.Data.InfoList) == 0 || resp.Data.Next == "-1" { + break + } + } + if len(res) != total { + log.Warnf("incorrect file count from remote at %s: expected %d, got %d", name, total, len(res)) + } + return res, nil +} diff --git a/drivers/123_link/driver.go b/drivers/123_link/driver.go new file mode 100644 index 0000000000000000000000000000000000000000..46cdcbae4827c9cc5e1ed1c52cc9dfe769d8ccbb --- /dev/null +++ b/drivers/123_link/driver.go @@ -0,0 +1,77 @@ +package _123Link + +import ( + "context" + stdpath "path" + "time" + + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/errs" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/pkg/utils" +) + +type Pan123Link struct { + model.Storage + Addition + root *Node +} + +func (d *Pan123Link) Config() driver.Config { + return config +} + +func (d *Pan123Link) GetAddition() driver.Additional { + return &d.Addition +} + +func (d *Pan123Link) Init(ctx context.Context) error { + node, err := BuildTree(d.OriginURLs) + if err != nil { + return err + } + node.calSize() + d.root = node + return nil +} + +func (d *Pan123Link) Drop(ctx context.Context) error { + return nil +} + +func (d *Pan123Link) Get(ctx context.Context, path string) (model.Obj, error) { + node := GetNodeFromRootByPath(d.root, path) + return nodeToObj(node, path) +} + +func (d *Pan123Link) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) { + node := GetNodeFromRootByPath(d.root, dir.GetPath()) + if node == nil { + return nil, errs.ObjectNotFound + } + if node.isFile() { + return nil, errs.NotFolder + } + return utils.SliceConvert(node.Children, func(node *Node) (model.Obj, error) { + return nodeToObj(node, stdpath.Join(dir.GetPath(), node.Name)) + }) +} + +func (d *Pan123Link) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) { + node := GetNodeFromRootByPath(d.root, file.GetPath()) + if node == nil { + return nil, errs.ObjectNotFound + } + if node.isFile() { + signUrl, err := SignURL(node.Url, d.PrivateKey, d.UID, time.Duration(d.ValidDuration)*time.Minute) + if err != nil { + return nil, err + } + return &model.Link{ + URL: signUrl, + }, nil + } + return nil, errs.NotFile +} + +var _ driver.Driver = (*Pan123Link)(nil) diff --git a/drivers/123_link/meta.go b/drivers/123_link/meta.go new file mode 100644 index 0000000000000000000000000000000000000000..9f35762869116205a29c2c2c63cdb278ce2ed620 --- /dev/null +++ b/drivers/123_link/meta.go @@ -0,0 +1,23 @@ +package _123Link + +import ( + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/op" +) + +type Addition struct { + OriginURLs string `json:"origin_urls" type:"text" required:"true" default:"https://vip.123pan.com/29/folder/file.mp3" help:"structure:FolderName:\n [FileSize:][Modified:]Url"` + PrivateKey string `json:"private_key"` + UID uint64 `json:"uid" type:"number"` + ValidDuration int64 `json:"valid_duration" type:"number" default:"30" help:"minutes"` +} + +var config = driver.Config{ + Name: "123PanLink", +} + +func init() { + op.RegisterDriver(func() driver.Driver { + return &Pan123Link{} + }) +} diff --git a/drivers/123_link/parse.go b/drivers/123_link/parse.go new file mode 100644 index 0000000000000000000000000000000000000000..8d6c3a13d02a02cdc1de320bd8118ab9a89d3a3d --- /dev/null +++ b/drivers/123_link/parse.go @@ -0,0 +1,152 @@ +package _123Link + +import ( + "fmt" + url2 "net/url" + stdpath "path" + "strconv" + 
"strings" + "time" +) + +// build tree from text, text structure definition: +/** + * FolderName: + * [FileSize:][Modified:]Url + */ +/** + * For example: + * folder1: + * name1:url1 + * url2 + * folder2: + * url3 + * url4 + * url5 + * folder3: + * url6 + * url7 + * url8 + */ +// if there are no name, use the last segment of url as name +func BuildTree(text string) (*Node, error) { + lines := strings.Split(text, "\n") + var root = &Node{Level: -1, Name: "root"} + stack := []*Node{root} + for _, line := range lines { + // calculate indent + indent := 0 + for i := 0; i < len(line); i++ { + if line[i] != ' ' { + break + } + indent++ + } + // if indent is not a multiple of 2, it is an error + if indent%2 != 0 { + return nil, fmt.Errorf("the line '%s' is not a multiple of 2", line) + } + // calculate level + level := indent / 2 + line = strings.TrimSpace(line[indent:]) + // if the line is empty, skip + if line == "" { + continue + } + // if level isn't greater than the level of the top of the stack + // it is not the child of the top of the stack + for level <= stack[len(stack)-1].Level { + // pop the top of the stack + stack = stack[:len(stack)-1] + } + // if the line is a folder + if isFolder(line) { + // create a new node + node := &Node{ + Level: level, + Name: strings.TrimSuffix(line, ":"), + } + // add the node to the top of the stack + stack[len(stack)-1].Children = append(stack[len(stack)-1].Children, node) + // push the node to the stack + stack = append(stack, node) + } else { + // if the line is a file + // create a new node + node, err := parseFileLine(line) + if err != nil { + return nil, err + } + node.Level = level + // add the node to the top of the stack + stack[len(stack)-1].Children = append(stack[len(stack)-1].Children, node) + } + } + return root, nil +} + +func isFolder(line string) bool { + return strings.HasSuffix(line, ":") +} + +// line definition: +// [FileSize:][Modified:]Url +func parseFileLine(line string) (*Node, error) { + // if there is no url, it is an error + if !strings.Contains(line, "http://") && !strings.Contains(line, "https://") { + return nil, fmt.Errorf("invalid line: %s, because url is required for file", line) + } + index := strings.Index(line, "http://") + if index == -1 { + index = strings.Index(line, "https://") + } + url := line[index:] + info := line[:index] + node := &Node{ + Url: url, + } + name := stdpath.Base(url) + unescape, err := url2.PathUnescape(name) + if err == nil { + name = unescape + } + node.Name = name + if index > 0 { + if !strings.HasSuffix(info, ":") { + return nil, fmt.Errorf("invalid line: %s, because file info must end with ':'", line) + } + info = info[:len(info)-1] + if info == "" { + return nil, fmt.Errorf("invalid line: %s, because file name can't be empty", line) + } + infoParts := strings.Split(info, ":") + size, err := strconv.ParseInt(infoParts[0], 10, 64) + if err != nil { + return nil, fmt.Errorf("invalid line: %s, because file size must be an integer", line) + } + node.Size = size + if len(infoParts) > 1 { + modified, err := strconv.ParseInt(infoParts[1], 10, 64) + if err != nil { + return nil, fmt.Errorf("invalid line: %s, because file modified must be an unix timestamp", line) + } + node.Modified = modified + } else { + node.Modified = time.Now().Unix() + } + } + return node, nil +} + +func splitPath(path string) []string { + if path == "/" { + return []string{"root"} + } + parts := strings.Split(path, "/") + parts[0] = "root" + return parts +} + +func GetNodeFromRootByPath(root *Node, path string) *Node { + 
return root.getByPath(splitPath(path)) +} diff --git a/drivers/123_link/types.go b/drivers/123_link/types.go new file mode 100644 index 0000000000000000000000000000000000000000..3fb040eb831f686d6d072c68e4c04f3be13ac95c --- /dev/null +++ b/drivers/123_link/types.go @@ -0,0 +1,66 @@ +package _123Link + +import ( + "time" + + "github.com/alist-org/alist/v3/internal/errs" + "github.com/alist-org/alist/v3/internal/model" +) + +// Node is a node in the folder tree +type Node struct { + Url string + Name string + Level int + Modified int64 + Size int64 + Children []*Node +} + +func (node *Node) getByPath(paths []string) *Node { + if len(paths) == 0 || node == nil { + return nil + } + if node.Name != paths[0] { + return nil + } + if len(paths) == 1 { + return node + } + for _, child := range node.Children { + tmp := child.getByPath(paths[1:]) + if tmp != nil { + return tmp + } + } + return nil +} + +func (node *Node) isFile() bool { + return node.Url != "" +} + +func (node *Node) calSize() int64 { + if node.isFile() { + return node.Size + } + var size int64 = 0 + for _, child := range node.Children { + size += child.calSize() + } + node.Size = size + return size +} + +func nodeToObj(node *Node, path string) (model.Obj, error) { + if node == nil { + return nil, errs.ObjectNotFound + } + return &model.Object{ + Name: node.Name, + Size: node.Size, + Modified: time.Unix(node.Modified, 0), + IsFolder: !node.isFile(), + Path: path, + }, nil +} diff --git a/drivers/123_link/util.go b/drivers/123_link/util.go new file mode 100644 index 0000000000000000000000000000000000000000..29c9b54d57621ffaa1fd004300043e4b35fc6c92 --- /dev/null +++ b/drivers/123_link/util.go @@ -0,0 +1,30 @@ +package _123Link + +import ( + "crypto/md5" + "fmt" + "math/rand" + "net/url" + "time" +) + +func SignURL(originURL, privateKey string, uid uint64, validDuration time.Duration) (newURL string, err error) { + if privateKey == "" { + return originURL, nil + } + var ( + ts = time.Now().Add(validDuration).Unix() // expiry timestamp + rInt = rand.Int() // random non-negative integer + objURL *url.URL + ) + objURL, err = url.Parse(originURL) + if err != nil { + return "", err + } + authKey := fmt.Sprintf("%d-%d-%d-%x", ts, rInt, uid, md5.Sum([]byte(fmt.Sprintf("%s-%d-%d-%d-%s", + objURL.Path, ts, rInt, uid, privateKey)))) + v := objURL.Query() + v.Add("auth_key", authKey) + objURL.RawQuery = v.Encode() + return objURL.String(), nil +} diff --git a/drivers/123_share/driver.go b/drivers/123_share/driver.go new file mode 100644 index 0000000000000000000000000000000000000000..9c1f3803710b86cd82df9d6872366227968701af --- /dev/null +++ b/drivers/123_share/driver.go @@ -0,0 +1,161 @@ +package _123Share + +import ( + "context" + "encoding/base64" + "fmt" + "golang.org/x/time/rate" + "net/http" + "net/url" + "sync" + "time" + + "github.com/alist-org/alist/v3/drivers/base" + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/errs" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/go-resty/resty/v2" + log "github.com/sirupsen/logrus" +) + +type Pan123Share struct { + model.Storage + Addition + apiRateLimit sync.Map +} + +func (d *Pan123Share) Config() driver.Config { + return config +} + +func (d *Pan123Share) GetAddition() driver.Additional { + return &d.Addition +} + +func (d *Pan123Share) Init(ctx context.Context) error { + // TODO login / refresh token + //op.MustSaveDriverStorage(d) + return nil +} + +func (d *Pan123Share) Drop(ctx context.Context) error { + return nil +} + +func
(d *Pan123Share) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) { + // TODO return the files list, required + files, err := d.getFiles(ctx, dir.GetID()) + if err != nil { + return nil, err + } + return utils.SliceConvert(files, func(src File) (model.Obj, error) { + return src, nil + }) +} + +func (d *Pan123Share) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) { + // TODO return link of file, required + if f, ok := file.(File); ok { + //var resp DownResp + var headers map[string]string + if !utils.IsLocalIPAddr(args.IP) { + headers = map[string]string{ + //"X-Real-IP": "1.1.1.1", + "X-Forwarded-For": args.IP, + } + } + data := base.Json{ + "shareKey": d.ShareKey, + "SharePwd": d.SharePwd, + "etag": f.Etag, + "fileId": f.FileId, + "s3keyFlag": f.S3KeyFlag, + "size": f.Size, + } + resp, err := d.request(DownloadInfo, http.MethodPost, func(req *resty.Request) { + req.SetBody(data).SetHeaders(headers) + }, nil) + if err != nil { + return nil, err + } + downloadUrl := utils.Json.Get(resp, "data", "DownloadURL").ToString() + u, err := url.Parse(downloadUrl) + if err != nil { + return nil, err + } + nu := u.Query().Get("params") + if nu != "" { + du, _ := base64.StdEncoding.DecodeString(nu) + u, err = url.Parse(string(du)) + if err != nil { + return nil, err + } + } + u_ := u.String() + log.Debug("download url: ", u_) + res, err := base.NoRedirectClient.R().SetHeader("Referer", "https://www.123pan.com/").Get(u_) + if err != nil { + return nil, err + } + log.Debug(res.String()) + link := model.Link{ + URL: u_, + } + log.Debugln("res code: ", res.StatusCode()) + if res.StatusCode() == 302 { + link.URL = res.Header().Get("location") + } else if res.StatusCode() < 300 { + link.URL = utils.Json.Get(res.Body(), "data", "redirect_url").ToString() + } + link.Header = http.Header{ + "Referer": []string{"https://www.123pan.com/"}, + } + return &link, nil + } + return nil, fmt.Errorf("can't convert obj") +} + +func (d *Pan123Share) MakeDir(ctx context.Context, parentDir model.Obj, dirName string) error { + // TODO create folder, optional + return errs.NotSupport +} + +func (d *Pan123Share) Move(ctx context.Context, srcObj, dstDir model.Obj) error { + // TODO move obj, optional + return errs.NotSupport +} + +func (d *Pan123Share) Rename(ctx context.Context, srcObj model.Obj, newName string) error { + // TODO rename obj, optional + return errs.NotSupport +} + +func (d *Pan123Share) Copy(ctx context.Context, srcObj, dstDir model.Obj) error { + // TODO copy obj, optional + return errs.NotSupport +} + +func (d *Pan123Share) Remove(ctx context.Context, obj model.Obj) error { + // TODO remove obj, optional + return errs.NotSupport +} + +func (d *Pan123Share) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error { + // TODO upload file, optional + return errs.NotSupport +} + +//func (d *Pan123Share) Other(ctx context.Context, args model.OtherArgs) (interface{}, error) { +// return nil, errs.NotSupport +//} + +func (d *Pan123Share) APIRateLimit(ctx context.Context, api string) error { + value, _ := d.apiRateLimit.LoadOrStore(api, + rate.NewLimiter(rate.Every(700*time.Millisecond), 1)) + limiter := value.(*rate.Limiter) + + return limiter.Wait(ctx) +} + +var _ driver.Driver = (*Pan123Share)(nil) diff --git a/drivers/123_share/meta.go b/drivers/123_share/meta.go new file mode 100644 index 0000000000000000000000000000000000000000..7cbcba277245f8392faf008b577f4177fcd65dbd --- /dev/null +++ 
b/drivers/123_share/meta.go @@ -0,0 +1,35 @@ +package _123Share + +import ( + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/op" +) + +type Addition struct { + ShareKey string `json:"sharekey" required:"true"` + SharePwd string `json:"sharepassword"` + driver.RootID + //OrderBy string `json:"order_by" type:"select" options:"file_name,size,update_at" default:"file_name"` + //OrderDirection string `json:"order_direction" type:"select" options:"asc,desc" default:"asc"` + AccessToken string `json:"accesstoken" type:"text"` +} + +var config = driver.Config{ + Name: "123PanShare", + LocalSort: true, + OnlyLocal: false, + OnlyProxy: false, + NoCache: false, + NoUpload: true, + NeedMs: false, + DefaultRoot: "0", + CheckStatus: false, + Alert: "", + NoOverwriteUpload: false, +} + +func init() { + op.RegisterDriver(func() driver.Driver { + return &Pan123Share{} + }) +} diff --git a/drivers/123_share/types.go b/drivers/123_share/types.go new file mode 100644 index 0000000000000000000000000000000000000000..e8ca9e7744065183731a9ee981fdf2f8132cafae --- /dev/null +++ b/drivers/123_share/types.go @@ -0,0 +1,99 @@ +package _123Share + +import ( + "github.com/alist-org/alist/v3/pkg/utils" + "net/url" + "path" + "strconv" + "strings" + "time" + + "github.com/alist-org/alist/v3/internal/model" +) + +type File struct { + FileName string `json:"FileName"` + Size int64 `json:"Size"` + UpdateAt time.Time `json:"UpdateAt"` + FileId int64 `json:"FileId"` + Type int `json:"Type"` + Etag string `json:"Etag"` + S3KeyFlag string `json:"S3KeyFlag"` + DownloadUrl string `json:"DownloadUrl"` +} + +func (f File) GetHash() utils.HashInfo { + return utils.HashInfo{} +} + +func (f File) GetPath() string { + return "" +} + +func (f File) GetSize() int64 { + return f.Size +} + +func (f File) GetName() string { + return f.FileName +} + +func (f File) ModTime() time.Time { + return f.UpdateAt +} +func (f File) CreateTime() time.Time { + return f.UpdateAt +} + +func (f File) IsDir() bool { + return f.Type == 1 +} + +func (f File) GetID() string { + return strconv.FormatInt(f.FileId, 10) +} + +func (f File) Thumb() string { + if f.DownloadUrl == "" { + return "" + } + du, err := url.Parse(f.DownloadUrl) + if err != nil { + return "" + } + du.Path = strings.TrimSuffix(du.Path, "_24_24") + "_70_70" + query := du.Query() + query.Set("w", "70") + query.Set("h", "70") + if !query.Has("type") { + query.Set("type", strings.TrimPrefix(path.Base(f.FileName), ".")) + } + if !query.Has("trade_key") { + query.Set("trade_key", "123pan-thumbnail") + } + du.RawQuery = query.Encode() + return du.String() +} + +var _ model.Obj = (*File)(nil) +var _ model.Thumb = (*File)(nil) + +//func (f File) Thumb() string { +// +//} +//var _ model.Thumb = (*File)(nil) + +type Files struct { + //BaseResp + Data struct { + InfoList []File `json:"InfoList"` + Next string `json:"Next"` + } `json:"data"` +} + +//type DownResp struct { +// //BaseResp +// Data struct { +// DownloadUrl string `json:"DownloadUrl"` +// } `json:"data"` +//} diff --git a/drivers/123_share/util.go b/drivers/123_share/util.go new file mode 100644 index 0000000000000000000000000000000000000000..80ea8f0ca46a1d159fcfa492e403de3555c5610b --- /dev/null +++ b/drivers/123_share/util.go @@ -0,0 +1,117 @@ +package _123Share + +import ( + "context" + "errors" + "fmt" + "hash/crc32" + "math" + "math/rand" + "net/http" + "net/url" + "strconv" + "strings" + "time" + + "github.com/alist-org/alist/v3/drivers/base" + "github.com/alist-org/alist/v3/pkg/utils" + 
"github.com/go-resty/resty/v2" + jsoniter "github.com/json-iterator/go" +) + +const ( + Api = "https://www.123pan.com/api" + AApi = "https://www.123pan.com/a/api" + BApi = "https://www.123pan.com/b/api" + MainApi = BApi + FileList = MainApi + "/share/get" + DownloadInfo = MainApi + "/share/download/info" + //AuthKeySalt = "8-8D$sL8gPjom7bk#cY" +) + +func signPath(path string, os string, version string) (k string, v string) { + table := []byte{'a', 'd', 'e', 'f', 'g', 'h', 'l', 'm', 'y', 'i', 'j', 'n', 'o', 'p', 'k', 'q', 'r', 's', 't', 'u', 'b', 'c', 'v', 'w', 's', 'z'} + random := fmt.Sprintf("%.f", math.Round(1e7*rand.Float64())) + now := time.Now().In(time.FixedZone("CST", 8*3600)) + timestamp := fmt.Sprint(now.Unix()) + nowStr := []byte(now.Format("200601021504")) + for i := 0; i < len(nowStr); i++ { + nowStr[i] = table[nowStr[i]-48] + } + timeSign := fmt.Sprint(crc32.ChecksumIEEE(nowStr)) + data := strings.Join([]string{timestamp, random, path, os, version, timeSign}, "|") + dataSign := fmt.Sprint(crc32.ChecksumIEEE([]byte(data))) + return timeSign, strings.Join([]string{timestamp, random, dataSign}, "-") +} + +func GetApi(rawUrl string) string { + u, _ := url.Parse(rawUrl) + query := u.Query() + query.Add(signPath(u.Path, "web", "3")) + u.RawQuery = query.Encode() + return u.String() +} + +func (d *Pan123Share) request(url string, method string, callback base.ReqCallback, resp interface{}) ([]byte, error) { + req := base.RestyClient.R() + req.SetHeaders(map[string]string{ + "origin": "https://www.123pan.com", + "referer": "https://www.123pan.com/", + "authorization": "Bearer " + d.AccessToken, + "user-agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) alist-client", + "platform": "web", + "app-version": "3", + //"user-agent": base.UserAgent, + }) + if callback != nil { + callback(req) + } + if resp != nil { + req.SetResult(resp) + } + res, err := req.Execute(method, GetApi(url)) + if err != nil { + return nil, err + } + body := res.Body() + code := utils.Json.Get(body, "code").ToInt() + if code != 0 { + return nil, errors.New(jsoniter.Get(body, "message").ToString()) + } + return body, nil +} + +func (d *Pan123Share) getFiles(ctx context.Context, parentId string) ([]File, error) { + page := 1 + res := make([]File, 0) + for { + if err := d.APIRateLimit(ctx, FileList); err != nil { + return nil, err + } + var resp Files + query := map[string]string{ + "limit": "100", + "next": "0", + "orderBy": "file_id", + "orderDirection": "desc", + "parentFileId": parentId, + "Page": strconv.Itoa(page), + "shareKey": d.ShareKey, + "SharePwd": d.SharePwd, + } + _, err := d.request(FileList, http.MethodGet, func(req *resty.Request) { + req.SetQueryParams(query) + }, &resp) + if err != nil { + return nil, err + } + page++ + res = append(res, resp.Data.InfoList...) 
+ if len(resp.Data.InfoList) == 0 || resp.Data.Next == "-1" { + break + } + } + return res, nil +} + +// do others that not defined in Driver interface diff --git a/drivers/139/driver.go b/drivers/139/driver.go new file mode 100644 index 0000000000000000000000000000000000000000..42f13c20a86b7f0fdaa31ebc0d2ee898a81ec147 --- /dev/null +++ b/drivers/139/driver.go @@ -0,0 +1,653 @@ +package _139 + +import ( + "context" + "encoding/base64" + "fmt" + "io" + "net/http" + "strconv" + "strings" + "time" + + "github.com/alist-org/alist/v3/drivers/base" + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/errs" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/pkg/cron" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/google/uuid" + log "github.com/sirupsen/logrus" +) + +type Yun139 struct { + model.Storage + Addition + cron *cron.Cron + Account string +} + +func (d *Yun139) Config() driver.Config { + return config +} + +func (d *Yun139) GetAddition() driver.Additional { + return &d.Addition +} + +func (d *Yun139) Init(ctx context.Context) error { + if d.Authorization == "" { + return fmt.Errorf("authorization is empty") + } + d.cron = cron.NewCron(time.Hour * 24 * 7) + d.cron.Do(func() { + err := d.refreshToken() + if err != nil { + log.Errorf("%+v", err) + } + }) + switch d.Addition.Type { + case MetaPersonalNew: + if len(d.Addition.RootFolderID) == 0 { + d.RootFolderID = "/" + } + return nil + case MetaPersonal: + if len(d.Addition.RootFolderID) == 0 { + d.RootFolderID = "root" + } + fallthrough + case MetaFamily: + decode, err := base64.StdEncoding.DecodeString(d.Authorization) + if err != nil { + return err + } + decodeStr := string(decode) + splits := strings.Split(decodeStr, ":") + if len(splits) < 2 { + return fmt.Errorf("authorization is invalid, splits < 2") + } + d.Account = splits[1] + _, err = d.post("/orchestration/personalCloud/user/v1.0/qryUserExternInfo", base.Json{ + "qryUserExternInfoReq": base.Json{ + "commonAccountInfo": base.Json{ + "account": d.Account, + "accountType": 1, + }, + }, + }, nil) + return err + default: + return errs.NotImplement + } +} + +func (d *Yun139) Drop(ctx context.Context) error { + if d.cron != nil { + d.cron.Stop() + } + return nil +} + +func (d *Yun139) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) { + switch d.Addition.Type { + case MetaPersonalNew: + return d.personalGetFiles(dir.GetID()) + case MetaPersonal: + return d.getFiles(dir.GetID()) + case MetaFamily: + return d.familyGetFiles(dir.GetID()) + default: + return nil, errs.NotImplement + } +} + +func (d *Yun139) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) { + var url string + var err error + switch d.Addition.Type { + case MetaPersonalNew: + url, err = d.personalGetLink(file.GetID()) + case MetaPersonal: + fallthrough + case MetaFamily: + url, err = d.getLink(file.GetID()) + default: + return nil, errs.NotImplement + } + if err != nil { + return nil, err + } + return &model.Link{URL: url}, nil +} + +func (d *Yun139) MakeDir(ctx context.Context, parentDir model.Obj, dirName string) error { + var err error + switch d.Addition.Type { + case MetaPersonalNew: + data := base.Json{ + "parentFileId": parentDir.GetID(), + "name": dirName, + "description": "", + "type": "folder", + "fileRenameMode": "force_rename", + } + pathname := "/hcy/file/create" + _, err = d.personalPost(pathname, data, nil) + case MetaPersonal: + data := base.Json{ + 
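+ // request body for the legacy personal-cloud catalog API; commonAccountInfo echoes the account that Init parsed from the base64 Authorization value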
"createCatalogExtReq": base.Json{ + "parentCatalogID": parentDir.GetID(), + "newCatalogName": dirName, + "commonAccountInfo": base.Json{ + "account": d.Account, + "accountType": 1, + }, + }, + } + pathname := "/orchestration/personalCloud/catalog/v1.0/createCatalogExt" + _, err = d.post(pathname, data, nil) + case MetaFamily: + cataID := parentDir.GetID() + path := cataID + data := base.Json{ + "cloudID": d.CloudID, + "commonAccountInfo": base.Json{ + "account": d.Account, + "accountType": 1, + }, + "docLibName": dirName, + "path": path, + } + pathname := "/orchestration/familyCloud-rebuild/cloudCatalog/v1.0/createCloudDoc" + _, err = d.post(pathname, data, nil) + default: + err = errs.NotImplement + } + return err +} + +func (d *Yun139) Move(ctx context.Context, srcObj, dstDir model.Obj) (model.Obj, error) { + switch d.Addition.Type { + case MetaPersonalNew: + data := base.Json{ + "fileIds": []string{srcObj.GetID()}, + "toParentFileId": dstDir.GetID(), + } + pathname := "/hcy/file/batchMove" + _, err := d.personalPost(pathname, data, nil) + if err != nil { + return nil, err + } + return srcObj, nil + case MetaPersonal: + var contentInfoList []string + var catalogInfoList []string + if srcObj.IsDir() { + catalogInfoList = append(catalogInfoList, srcObj.GetID()) + } else { + contentInfoList = append(contentInfoList, srcObj.GetID()) + } + data := base.Json{ + "createBatchOprTaskReq": base.Json{ + "taskType": 3, + "actionType": "304", + "taskInfo": base.Json{ + "contentInfoList": contentInfoList, + "catalogInfoList": catalogInfoList, + "newCatalogID": dstDir.GetID(), + }, + "commonAccountInfo": base.Json{ + "account": d.Account, + "accountType": 1, + }, + }, + } + pathname := "/orchestration/personalCloud/batchOprTask/v1.0/createBatchOprTask" + _, err := d.post(pathname, data, nil) + if err != nil { + return nil, err + } + return srcObj, nil + default: + return nil, errs.NotImplement + } +} + +func (d *Yun139) Rename(ctx context.Context, srcObj model.Obj, newName string) error { + var err error + switch d.Addition.Type { + case MetaPersonalNew: + data := base.Json{ + "fileId": srcObj.GetID(), + "name": newName, + "description": "", + } + pathname := "/hcy/file/update" + _, err = d.personalPost(pathname, data, nil) + case MetaPersonal: + var data base.Json + var pathname string + if srcObj.IsDir() { + data = base.Json{ + "catalogID": srcObj.GetID(), + "catalogName": newName, + "commonAccountInfo": base.Json{ + "account": d.Account, + "accountType": 1, + }, + } + pathname = "/orchestration/personalCloud/catalog/v1.0/updateCatalogInfo" + } else { + data = base.Json{ + "contentID": srcObj.GetID(), + "contentName": newName, + "commonAccountInfo": base.Json{ + "account": d.Account, + "accountType": 1, + }, + } + pathname = "/orchestration/personalCloud/content/v1.0/updateContentInfo" + } + _, err = d.post(pathname, data, nil) + default: + err = errs.NotImplement + } + return err +} + +func (d *Yun139) Copy(ctx context.Context, srcObj, dstDir model.Obj) error { + var err error + switch d.Addition.Type { + case MetaPersonalNew: + data := base.Json{ + "fileIds": []string{srcObj.GetID()}, + "toParentFileId": dstDir.GetID(), + } + pathname := "/hcy/file/batchCopy" + _, err := d.personalPost(pathname, data, nil) + return err + case MetaPersonal: + var contentInfoList []string + var catalogInfoList []string + if srcObj.IsDir() { + catalogInfoList = append(catalogInfoList, srcObj.GetID()) + } else { + contentInfoList = append(contentInfoList, srcObj.GetID()) + } + data := base.Json{ + 
"createBatchOprTaskReq": base.Json{ + "taskType": 3, + "actionType": 309, + "taskInfo": base.Json{ + "contentInfoList": contentInfoList, + "catalogInfoList": catalogInfoList, + "newCatalogID": dstDir.GetID(), + }, + "commonAccountInfo": base.Json{ + "account": d.Account, + "accountType": 1, + }, + }, + } + pathname := "/orchestration/personalCloud/batchOprTask/v1.0/createBatchOprTask" + _, err = d.post(pathname, data, nil) + default: + err = errs.NotImplement + } + return err +} + +func (d *Yun139) Remove(ctx context.Context, obj model.Obj) error { + switch d.Addition.Type { + case MetaPersonalNew: + data := base.Json{ + "fileIds": []string{obj.GetID()}, + } + pathname := "/hcy/recyclebin/batchTrash" + _, err := d.personalPost(pathname, data, nil) + return err + case MetaPersonal: + fallthrough + case MetaFamily: + return errs.NotImplement + log.Warn("==========================================") + var contentInfoList []string + var catalogInfoList []string + cataID := obj.GetID() + path := "" + if strings.Contains(cataID, "/") { + lastSlashIndex := strings.LastIndex(cataID, "/") + path = cataID[0:lastSlashIndex] + cataID = cataID[lastSlashIndex+1:] + } + + if obj.IsDir() { + catalogInfoList = append(catalogInfoList, cataID) + } else { + contentInfoList = append(contentInfoList, cataID) + } + data := base.Json{ + "createBatchOprTaskReq": base.Json{ + "taskType": 2, + "actionType": 201, + "taskInfo": base.Json{ + "newCatalogID": "", + "contentInfoList": contentInfoList, + "catalogInfoList": catalogInfoList, + }, + "commonAccountInfo": base.Json{ + "account": d.Account, + "accountType": 1, + }, + }, + } + pathname := "/orchestration/personalCloud/batchOprTask/v1.0/createBatchOprTask" + if d.isFamily() { + data = base.Json{ + "taskType": 2, + "sourceCloudID": d.CloudID, + "sourceCatalogType": 1002, + "path": path, + "contentList": catalogInfoList, + "catalogList": contentInfoList, + "commonAccountInfo": base.Json{ + "account": d.Account, + "accountType": 1, + }, + } + pathname = "/orchestration/familyCloud-rebuild/batchOprTask/v1.0/createBatchOprTask" + } + _, err := d.post(pathname, data, nil) + return err + default: + return errs.NotImplement + } +} + +const ( + _ = iota //ignore first value by assigning to blank identifier + KB = 1 << (10 * iota) + MB + GB + TB +) + +func getPartSize(size int64) int64 { + // 网盘对于分片数量存在上限 + if size/GB > 30 { + return 512 * MB + } + return 350 * MB +} + + + +func (d *Yun139) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error { + switch d.Addition.Type { + case MetaPersonalNew: + var err error + fullHash := stream.GetHash().GetHash(utils.SHA256) + if len(fullHash) <= 0 { + tmpF, err := stream.CacheFullInTempFile() + if err != nil { + return err + } + fullHash, err = utils.HashFile(utils.SHA256, tmpF) + if err != nil { + return err + } + } + + partInfos := []PartInfo{} + var partSize = getPartSize(stream.GetSize()) + part := (stream.GetSize() + partSize - 1) / partSize + if part == 0 { + part = 1 + } + for i := int64(0); i < part; i++ { + if utils.IsCanceled(ctx) { + return ctx.Err() + } + + start := i * partSize + byteSize := stream.GetSize() - start + if byteSize > partSize { + byteSize = partSize + } + partNumber := i + 1 + partInfo := PartInfo{ + PartNumber: partNumber, + PartSize: byteSize, + ParallelHashCtx: ParallelHashCtx{ + PartOffset: start, + }, + } + partInfos = append(partInfos, partInfo) + } + + // return errs.NotImplement + data := base.Json{ + "contentHash": fullHash, + "contentHashAlgorithm": 
"SHA256", + "contentType": "application/octet-stream", + "parallelUpload": false, + "partInfos": partInfos, + "size": stream.GetSize(), + "parentFileId": dstDir.GetID(), + "name": stream.GetName(), + "type": "file", + "fileRenameMode": "auto_rename", + } + pathname := "/hcy/file/create" + var resp PersonalUploadResp + _, err = d.personalPost(pathname, data, &resp) + if err != nil { + return err + } + + if resp.Data.Exist || resp.Data.RapidUpload { + return nil + } + + // Progress + p := driver.NewProgress(stream.GetSize(), up) + + // Update Progress + // r := io.TeeReader(stream, p) + + for index, partInfo := range resp.Data.PartInfos { + + int64Index := int64(index) + start := int64Index * partSize + byteSize := stream.GetSize() - start + if byteSize > partSize { + byteSize = partSize + } + + retry := 2 // 只允许重试 2 次 + for attempt := 0; attempt <= retry; attempt++ { + limitReader := io.LimitReader(stream, byteSize) + // Update Progress + r := io.TeeReader(limitReader, p) + req, err := http.NewRequest("PUT", partInfo.UploadUrl, r) + if err != nil { + return err + } + req = req.WithContext(ctx) + req.Header.Set("Content-Type", "application/octet-stream") + req.Header.Set("Content-Length", fmt.Sprint(byteSize)) + req.Header.Set("Origin", "https://yun.139.com") + req.Header.Set("Referer", "https://yun.139.com/") + req.ContentLength = byteSize + + res, err := base.HttpClient.Do(req) + if err != nil { + return err + } + + _ = res.Body.Close() + log.Debugf("%+v", res) + if res.StatusCode != http.StatusOK { + if res.StatusCode == http.StatusRequestTimeout && attempt < retry{ + log.Warn("服务器返回 408,尝试重试...") + continue + }else{ + return fmt.Errorf("unexpected status code: %d", res.StatusCode) + } + } + break + } + } + + data = base.Json{ + "contentHash": fullHash, + "contentHashAlgorithm": "SHA256", + "fileId": resp.Data.FileId, + "uploadId": resp.Data.UploadId, + } + _, err = d.personalPost("/hcy/file/complete", data, nil) + if err != nil { + return err + } + return nil + case MetaPersonal: + fallthrough + case MetaFamily: + data := base.Json{ + "manualRename": 2, + "operation": 0, + "fileCount": 1, + "totalSize": 0, // 去除上传大小限制 + "uploadContentList": []base.Json{{ + "contentName": stream.GetName(), + "contentSize": stream.GetSize(), // 去除上传大小限制 + // "digest": "5a3231986ce7a6b46e408612d385bafa" + }}, + "parentCatalogID": dstDir.GetID(), + "newCatalogName": "", + "commonAccountInfo": base.Json{ + "account": d.Account, + "accountType": 1, + }, + } + pathname := "/orchestration/personalCloud/uploadAndDownload/v1.0/pcUploadFileRequest" + if d.isFamily() { + cataID := dstDir.GetID() + path := cataID + seqNo, _ := uuid.NewUUID() + data = base.Json{ + "cloudID": d.CloudID, + "path": path, + "operation": 0, + "cloudType": 1, + "catalogType": 3, + "manualRename": 2, + "fileCount": 1, + "totalSize": stream.GetSize(), + "uploadContentList": []base.Json{{ + "contentName": stream.GetName(), + "contentSize": stream.GetSize(), + }}, + "seqNo": seqNo, + "commonAccountInfo": base.Json{ + "account": d.Account, + "accountType": 1, + }, + } + pathname = "/orchestration/familyCloud-rebuild/content/v1.0/getFileUploadURL" + //return errs.NotImplement + } + var resp UploadResp + _, err := d.post(pathname, data, &resp) + if err != nil { + return err + } + // Progress + p := driver.NewProgress(stream.GetSize(), up) + + var partSize = getPartSize(stream.GetSize()) + //var partSize = stream.GetSize() + part := (stream.GetSize() + partSize - 1) / partSize + if part == 0 { + part = 1 + } + for i := int64(0); i < part; i++ { + 
if utils.IsCanceled(ctx) { + return ctx.Err() + } + + start := i * partSize + byteSize := stream.GetSize() - start + if byteSize > partSize { + byteSize = partSize + } + + retry := 2 // allow at most 2 retries per part + for attempt := 0; attempt <= retry; attempt++ { + limitReader := io.LimitReader(stream, byteSize) + // Update Progress + r := io.TeeReader(limitReader, p) + req, err := http.NewRequest("POST", resp.Data.UploadResult.RedirectionURL, r) + if err != nil { + return err + } + + req = req.WithContext(ctx) + req.Header.Set("Content-Type", "text/plain;name="+unicode(stream.GetName())) + req.Header.Set("contentSize", strconv.FormatInt(stream.GetSize(), 10)) + req.Header.Set("range", fmt.Sprintf("bytes=%d-%d", start, start+byteSize-1)) + req.Header.Set("uploadtaskID", resp.Data.UploadResult.UploadTaskID) + req.Header.Set("rangeType", "0") + req.ContentLength = byteSize + + res, err := base.HttpClient.Do(req) + if err != nil { + return err + } + _ = res.Body.Close() + log.Debugf("%+v", res) + if res.StatusCode != http.StatusOK { + if res.StatusCode == http.StatusRequestTimeout && attempt < retry { + log.Warn("server returned 408, retrying...") + continue + } else { + return fmt.Errorf("unexpected status code: %d", res.StatusCode) + } + } + break + } + } + + return nil + default: + return errs.NotImplement + } +} + +func (d *Yun139) Other(ctx context.Context, args model.OtherArgs) (interface{}, error) { + switch d.Addition.Type { + case MetaPersonalNew: + var resp base.Json + var uri string + data := base.Json{ + "category": "video", + "fileId": args.Obj.GetID(), + } + switch args.Method { + case "video_preview": + uri = "/hcy/videoPreview/getPreviewInfo" + default: + return nil, errs.NotSupport + } + _, err := d.personalPost(uri, data, &resp) + if err != nil { + return nil, err + } + return resp["data"], nil + default: + return nil, errs.NotImplement + } +} + +var _ driver.Driver = (*Yun139)(nil) diff --git a/drivers/139/meta.go b/drivers/139/meta.go new file mode 100644 index 0000000000000000000000000000000000000000..416e63a796cb636524145e53dab5bb4f69c7aeeb --- /dev/null +++ b/drivers/139/meta.go @@ -0,0 +1,25 @@ +package _139 + +import ( + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/op" +) + +type Addition struct { + //Account string `json:"account" required:"true"` + Authorization string `json:"authorization" type:"text" required:"true"` + driver.RootID + Type string `json:"type" type:"select" options:"personal,family,personal_new" default:"personal"` + CloudID string `json:"cloud_id"` +} + +var config = driver.Config{ + Name: "139Yun", + LocalSort: true, +} + +func init() { + op.RegisterDriver(func() driver.Driver { + return &Yun139{} + }) +} diff --git a/drivers/139/types.go b/drivers/139/types.go new file mode 100644 index 0000000000000000000000000000000000000000..3cbf14dbaa4cc13323d0a526c77ca008b285e77f --- /dev/null +++ b/drivers/139/types.go @@ -0,0 +1,255 @@ +package _139 + +import ( + "encoding/xml" +) + +const ( + MetaPersonal string = "personal" + MetaFamily string = "family" + MetaPersonalNew string = "personal_new" +) + +type BaseResp struct { + Success bool `json:"success"` + Code string `json:"code"` + Message string `json:"message"` +} + +type Catalog struct { + CatalogID string `json:"catalogID"` + CatalogName string `json:"catalogName"` + //CatalogType int `json:"catalogType"` + CreateTime string `json:"createTime"` + UpdateTime string `json:"updateTime"` + //IsShared bool `json:"isShared"` + //CatalogLevel int `json:"catalogLevel"` + //ShareDoneeCount int
`json:"shareDoneeCount"` + //OpenType int `json:"openType"` + //ParentCatalogID string `json:"parentCatalogId"` + //DirEtag int `json:"dirEtag"` + //Tombstoned int `json:"tombstoned"` + //ProxyID interface{} `json:"proxyID"` + //Moved int `json:"moved"` + //IsFixedDir int `json:"isFixedDir"` + //IsSynced interface{} `json:"isSynced"` + //Owner string `json:"owner"` + //Modifier interface{} `json:"modifier"` + //Path string `json:"path"` + //ShareType int `json:"shareType"` + //SoftLink interface{} `json:"softLink"` + //ExtProp1 interface{} `json:"extProp1"` + //ExtProp2 interface{} `json:"extProp2"` + //ExtProp3 interface{} `json:"extProp3"` + //ExtProp4 interface{} `json:"extProp4"` + //ExtProp5 interface{} `json:"extProp5"` + //ETagOprType int `json:"ETagOprType"` +} + +type Content struct { + ContentID string `json:"contentID"` + ContentName string `json:"contentName"` + //ContentSuffix string `json:"contentSuffix"` + ContentSize int64 `json:"contentSize"` + //ContentDesc string `json:"contentDesc"` + //ContentType int `json:"contentType"` + //ContentOrigin int `json:"contentOrigin"` + UpdateTime string `json:"updateTime"` + //CommentCount int `json:"commentCount"` + ThumbnailURL string `json:"thumbnailURL"` + //BigthumbnailURL string `json:"bigthumbnailURL"` + //PresentURL string `json:"presentURL"` + //PresentLURL string `json:"presentLURL"` + //PresentHURL string `json:"presentHURL"` + //ContentTAGList interface{} `json:"contentTAGList"` + //ShareDoneeCount int `json:"shareDoneeCount"` + //Safestate int `json:"safestate"` + //Transferstate int `json:"transferstate"` + //IsFocusContent int `json:"isFocusContent"` + //UpdateShareTime interface{} `json:"updateShareTime"` + //UploadTime string `json:"uploadTime"` + //OpenType int `json:"openType"` + //AuditResult int `json:"auditResult"` + //ParentCatalogID string `json:"parentCatalogId"` + //Channel string `json:"channel"` + //GeoLocFlag string `json:"geoLocFlag"` + Digest string `json:"digest"` + //Version string `json:"version"` + //FileEtag string `json:"fileEtag"` + //FileVersion string `json:"fileVersion"` + //Tombstoned int `json:"tombstoned"` + //ProxyID string `json:"proxyID"` + //Moved int `json:"moved"` + //MidthumbnailURL string `json:"midthumbnailURL"` + //Owner string `json:"owner"` + //Modifier string `json:"modifier"` + //ShareType int `json:"shareType"` + //ExtInfo struct { + // Uploader string `json:"uploader"` + // Address string `json:"address"` + //} `json:"extInfo"` + //Exif struct { + // CreateTime string `json:"createTime"` + // Longitude interface{} `json:"longitude"` + // Latitude interface{} `json:"latitude"` + // LocalSaveTime interface{} `json:"localSaveTime"` + //} `json:"exif"` + //CollectionFlag interface{} `json:"collectionFlag"` + //TreeInfo interface{} `json:"treeInfo"` + //IsShared bool `json:"isShared"` + //ETagOprType int `json:"ETagOprType"` +} + +type GetDiskResp struct { + BaseResp + Data struct { + Result struct { + ResultCode string `json:"resultCode"` + ResultDesc interface{} `json:"resultDesc"` + } `json:"result"` + GetDiskResult struct { + ParentCatalogID string `json:"parentCatalogID"` + NodeCount int `json:"nodeCount"` + CatalogList []Catalog `json:"catalogList"` + ContentList []Content `json:"contentList"` + IsCompleted int `json:"isCompleted"` + } `json:"getDiskResult"` + } `json:"data"` +} + +type UploadResp struct { + BaseResp + Data struct { + Result struct { + ResultCode string `json:"resultCode"` + ResultDesc interface{} `json:"resultDesc"` + } `json:"result"` + UploadResult struct { 
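+ // RedirectionURL is the endpoint the file bytes are POSTed to and UploadTaskID ties each slice to the upload session; see the uploadtaskID and range headers set in Put.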
+ UploadTaskID string `json:"uploadTaskID"` + RedirectionURL string `json:"redirectionUrl"` + NewContentIDList []struct { + ContentID string `json:"contentID"` + ContentName string `json:"contentName"` + IsNeedUpload string `json:"isNeedUpload"` + FileEtag int64 `json:"fileEtag"` + FileVersion int64 `json:"fileVersion"` + OverridenFlag int `json:"overridenFlag"` + } `json:"newContentIDList"` + CatalogIDList interface{} `json:"catalogIDList"` + IsSlice interface{} `json:"isSlice"` + } `json:"uploadResult"` + } `json:"data"` +} + +type CloudContent struct { + ContentID string `json:"contentID"` + //Modifier string `json:"modifier"` + //Nickname string `json:"nickname"` + //CloudNickName string `json:"cloudNickName"` + ContentName string `json:"contentName"` + //ContentType int `json:"contentType"` + //ContentSuffix string `json:"contentSuffix"` + ContentSize int64 `json:"contentSize"` + //ContentDesc string `json:"contentDesc"` + CreateTime string `json:"createTime"` + //Shottime interface{} `json:"shottime"` + LastUpdateTime string `json:"lastUpdateTime"` + ThumbnailURL string `json:"thumbnailURL"` + //MidthumbnailURL string `json:"midthumbnailURL"` + //BigthumbnailURL string `json:"bigthumbnailURL"` + //PresentURL string `json:"presentURL"` + //PresentLURL string `json:"presentLURL"` + //PresentHURL string `json:"presentHURL"` + //ParentCatalogID string `json:"parentCatalogID"` + //Uploader string `json:"uploader"` + //UploaderNickName string `json:"uploaderNickName"` + //TreeInfo interface{} `json:"treeInfo"` + //UpdateTime interface{} `json:"updateTime"` + //ExtInfo struct { + // Uploader string `json:"uploader"` + //} `json:"extInfo"` + //EtagOprType interface{} `json:"etagOprType"` +} + +type CloudCatalog struct { + CatalogID string `json:"catalogID"` + CatalogName string `json:"catalogName"` + //CloudID string `json:"cloudID"` + CreateTime string `json:"createTime"` + LastUpdateTime string `json:"lastUpdateTime"` + //Creator string `json:"creator"` + //CreatorNickname string `json:"creatorNickname"` +} + +type QueryContentListResp struct { + BaseResp + Data struct { + Result struct { + ResultCode string `json:"resultCode"` + ResultDesc string `json:"resultDesc"` + } `json:"result"` + Path string `json:"path"` + CloudContentList []CloudContent `json:"cloudContentList"` + CloudCatalogList []CloudCatalog `json:"cloudCatalogList"` + TotalCount int `json:"totalCount"` + RecallContent interface{} `json:"recallContent"` + } `json:"data"` +} + +type PartInfo struct { + PartNumber int64 `json:"partNumber"` + PartSize int64 `json:"partSize"` + ParallelHashCtx ParallelHashCtx `json:"parallelHashCtx"` +} + +type ParallelHashCtx struct { + PartOffset int64 `json:"partOffset"` +} + +type PersonalThumbnail struct { + Style string `json:"style"` + Url string `json:"url"` +} + +type PersonalFileItem struct { + FileId string `json:"fileId"` + Name string `json:"name"` + Size int64 `json:"size"` + Type string `json:"type"` + CreatedAt string `json:"createdAt"` + UpdatedAt string `json:"updatedAt"` + Thumbnails []PersonalThumbnail `json:"thumbnailUrls"` +} + +type PersonalListResp struct { + BaseResp + Data struct { + Items []PersonalFileItem `json:"items"` + NextPageCursor string `json:"nextPageCursor"` + } +} + +type PersonalPartInfo struct { + PartNumber int `json:"partNumber"` + UploadUrl string `json:"uploadUrl"` +} + +type PersonalUploadResp struct { + BaseResp + Data struct { + FileId string `json:"fileId"` + PartInfos []PersonalPartInfo `json:"partInfos"` + Exist bool `json:"exist"` + RapidUpload 
bool `json:"rapidUpload"` + UploadId string `json:"uploadId"` + } +} + +type RefreshTokenResp struct { + XMLName xml.Name `xml:"root"` + Return string `xml:"return"` + Token string `xml:"token"` + Expiretime int32 `xml:"expiretime"` + AccessToken string `xml:"accessToken"` + Desc string `xml:"desc"` +} diff --git a/drivers/139/util.go b/drivers/139/util.go new file mode 100644 index 0000000000000000000000000000000000000000..404b54b5bd8d7af81f94d6de7ca85b5c0c72b9ff --- /dev/null +++ b/drivers/139/util.go @@ -0,0 +1,438 @@ +package _139 + +import ( + "encoding/base64" + "errors" + "fmt" + "net/http" + "net/url" + "sort" + "strconv" + "strings" + "time" + + "github.com/alist-org/alist/v3/drivers/base" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/internal/op" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/alist-org/alist/v3/pkg/utils/random" + "github.com/go-resty/resty/v2" + jsoniter "github.com/json-iterator/go" + log "github.com/sirupsen/logrus" +) + +// do others that are not defined in the Driver interface +func (d *Yun139) isFamily() bool { + return d.Type == "family" +} + +func encodeURIComponent(str string) string { + r := url.QueryEscape(str) + r = strings.Replace(r, "+", "%20", -1) + r = strings.Replace(r, "%21", "!", -1) + r = strings.Replace(r, "%27", "'", -1) + r = strings.Replace(r, "%28", "(", -1) + r = strings.Replace(r, "%29", ")", -1) + r = strings.Replace(r, "%2A", "*", -1) + return r +} + +// calSign percent-encodes the body, sorts its characters, base64-encodes the result, then returns the upper-cased MD5 of MD5(encoded) + MD5(ts + ":" + randStr) +func calSign(body, ts, randStr string) string { + body = encodeURIComponent(body) + strs := strings.Split(body, "") + sort.Strings(strs) + body = strings.Join(strs, "") + body = base64.StdEncoding.EncodeToString([]byte(body)) + res := utils.GetMD5EncodeStr(body) + utils.GetMD5EncodeStr(ts+":"+randStr) + res = strings.ToUpper(utils.GetMD5EncodeStr(res)) + return res +} + +func getTime(t string) time.Time { + stamp, _ := time.ParseInLocation("20060102150405", t, utils.CNLoc) + return stamp +} + +func (d *Yun139) refreshToken() error { + url := "https://aas.caiyun.feixin.10086.cn:443/tellin/authTokenRefresh.do" + var resp RefreshTokenResp + decode, err := base64.StdEncoding.DecodeString(d.Authorization) + if err != nil { + return err + } + decodeStr := string(decode) + splits := strings.Split(decodeStr, ":") + reqBody := "<root><token>" + splits[2] + "</token><account>" + splits[1] + "</account><clienttype>656</clienttype></root>" + _, err = base.RestyClient.R(). + ForceContentType("application/xml"). + SetBody(reqBody). + SetResult(&resp).
+ Post(url) + if err != nil { + return err + } + if resp.Return != "0" { + return fmt.Errorf("failed to refresh token: %s", resp.Desc) + } + d.Authorization = base64.StdEncoding.EncodeToString([]byte(splits[0] + ":" + splits[1] + ":" + resp.Token)) + op.MustSaveDriverStorage(d) + return nil +} + +func (d *Yun139) request(pathname string, method string, callback base.ReqCallback, resp interface{}) ([]byte, error) { + url := "https://yun.139.com" + pathname + req := base.RestyClient.R() + randStr := random.String(16) + ts := time.Now().Format("2006-01-02 15:04:05") + if callback != nil { + callback(req) + } + body, err := utils.Json.Marshal(req.Body) + if err != nil { + return nil, err + } + sign := calSign(string(body), ts, randStr) + svcType := "1" + if d.isFamily() { + svcType = "2" + } + req.SetHeaders(map[string]string{ + "Accept": "application/json, text/plain, */*", + "CMS-DEVICE": "default", + "Authorization": "Basic " + d.Authorization, + "mcloud-channel": "1000101", + "mcloud-client": "10701", + //"mcloud-route": "001", + "mcloud-sign": fmt.Sprintf("%s,%s,%s", ts, randStr, sign), + //"mcloud-skey":"", + "mcloud-version": "6.6.0", + "Origin": "https://yun.139.com", + "Referer": "https://yun.139.com/w/", + "x-DeviceInfo": "||9|6.6.0|chrome|95.0.4638.69|uwIy75obnsRPIwlJSd7D9GhUvFwG96ce||macos 10.15.2||zh-CN|||", + "x-huawei-channelSrc": "10000034", + "x-inner-ntwk": "2", + "x-m4c-caller": "PC", + "x-m4c-src": "10002", + "x-SvcType": svcType, + }) + + var e BaseResp + req.SetResult(&e) + res, err := req.Execute(method, url) + if err != nil { + return nil, err + } + log.Debugln(res.String()) + if !e.Success { + return nil, errors.New(e.Message) + } + if resp != nil { + err = utils.Json.Unmarshal(res.Body(), resp) + if err != nil { + return nil, err + } + } + return res.Body(), nil +} +func (d *Yun139) post(pathname string, data interface{}, resp interface{}) ([]byte, error) { + return d.request(pathname, http.MethodPost, func(req *resty.Request) { + req.SetBody(data) + }, resp) +} + +func (d *Yun139) getFiles(catalogID string) ([]model.Obj, error) { + start := 0 + limit := 100 + files := make([]model.Obj, 0) + for { + data := base.Json{ + "catalogID": catalogID, + "sortDirection": 1, + "startNumber": start + 1, + "endNumber": start + limit, + "filterType": 0, + "catalogSortType": 0, + "contentSortType": 0, + "commonAccountInfo": base.Json{ + "account": d.Account, + "accountType": 1, + }, + } + var resp GetDiskResp + _, err := d.post("/orchestration/personalCloud/catalog/v1.0/getDisk", data, &resp) + if err != nil { + return nil, err + } + for _, catalog := range resp.Data.GetDiskResult.CatalogList { + f := model.Object{ + ID: catalog.CatalogID, + Name: catalog.CatalogName, + Size: 0, + Modified: getTime(catalog.UpdateTime), + Ctime: getTime(catalog.CreateTime), + IsFolder: true, + } + files = append(files, &f) + } + for _, content := range resp.Data.GetDiskResult.ContentList { + f := model.ObjThumb{ + Object: model.Object{ + ID: content.ContentID, + Name: content.ContentName, + Size: content.ContentSize, + Modified: getTime(content.UpdateTime), + HashInfo: utils.NewHashInfo(utils.MD5, content.Digest), + }, + Thumbnail: model.Thumbnail{Thumbnail: content.ThumbnailURL}, + //Thumbnail: content.BigthumbnailURL, + } + files = append(files, &f) + } + if start+limit >= resp.Data.GetDiskResult.NodeCount { + break + } + start += limit + } + return files, nil +} + +func (d *Yun139) newJson(data map[string]interface{}) base.Json { + common := map[string]interface{}{ + "catalogType": 3, + "cloudID": d.CloudID, + "cloudType": 1, +
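// these family-cloud defaults are merged under every body built via newJson; catalogType 3 and cloudType 1 appear to be fixed constants of the web client +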
"commonAccountInfo": base.Json{ + "account": d.Account, + "accountType": 1, + }, + } + return utils.MergeMap(data, common) +} + +func (d *Yun139) familyGetFiles(catalogID string) ([]model.Obj, error) { + + if strings.Contains(catalogID, "/") { + lastSlashIndex := strings.LastIndex(catalogID, "/") + catalogID = catalogID[lastSlashIndex+1:] + } + + pageNum := 1 + files := make([]model.Obj, 0) + for { + data := d.newJson(base.Json{ + "catalogID": catalogID, + "contentSortType": 0, + "pageInfo": base.Json{ + "pageNum": pageNum, + "pageSize": 100, + }, + "sortDirection": 1, + }) + var resp QueryContentListResp + _, err := d.post("/orchestration/familyCloud/content/v1.0/queryContentList", data, &resp) + if err != nil { + return nil, err + } + for _, catalog := range resp.Data.CloudCatalogList { + f := model.Object{ + ID: resp.Data.Path + "/" + catalog.CatalogID, + Name: catalog.CatalogName, + Size: 0, + IsFolder: true, + Modified: getTime(catalog.LastUpdateTime), + Ctime: getTime(catalog.CreateTime), + } + files = append(files, &f) + } + for _, content := range resp.Data.CloudContentList { + f := model.ObjThumb{ + Object: model.Object{ + ID: content.ContentID, + Name: content.ContentName, + Size: content.ContentSize, + Modified: getTime(content.LastUpdateTime), + Ctime: getTime(content.CreateTime), + }, + Thumbnail: model.Thumbnail{Thumbnail: content.ThumbnailURL}, + //Thumbnail: content.BigthumbnailURL, + } + files = append(files, &f) + } + if 100*pageNum > resp.Data.TotalCount { + break + } + pageNum++ + } + return files, nil +} + +func (d *Yun139) getLink(contentId string) (string, error) { + data := base.Json{ + "appName": "", + "contentID": contentId, + "commonAccountInfo": base.Json{ + "account": d.Account, + "accountType": 1, + }, + } + res, err := d.post("/orchestration/personalCloud/uploadAndDownload/v1.0/downloadRequest", + data, nil) + if err != nil { + return "", err + } + return jsoniter.Get(res, "data", "downloadURL").ToString(), nil +} + +func unicode(str string) string { + textQuoted := strconv.QuoteToASCII(str) + textUnquoted := textQuoted[1 : len(textQuoted)-1] + return textUnquoted +} + +func (d *Yun139) personalRequest(pathname string, method string, callback base.ReqCallback, resp interface{}) ([]byte, error) { + url := "https://personal-kd-njs.yun.139.com" + pathname + req := base.RestyClient.R() + randStr := random.String(16) + ts := time.Now().Format("2006-01-02 15:04:05") + if callback != nil { + callback(req) + } + body, err := utils.Json.Marshal(req.Body) + if err != nil { + return nil, err + } + sign := calSign(string(body), ts, randStr) + svcType := "1" + if d.isFamily() { + svcType = "2" + } + req.SetHeaders(map[string]string{ + "Accept": "application/json, text/plain, */*", + "Authorization": "Basic " + d.Authorization, + "Caller": "web", + "Cms-Device": "default", + "Mcloud-Channel": "1000101", + "Mcloud-Client": "10701", + "Mcloud-Route": "001", + "Mcloud-Sign": fmt.Sprintf("%s,%s,%s", ts, randStr, sign), + "Mcloud-Version": "7.13.0", + "Origin": "https://yun.139.com", + "Referer": "https://yun.139.com/w/", + "x-DeviceInfo": "||9|7.13.0|chrome|120.0.0.0|||windows 10||zh-CN|||", + "x-huawei-channelSrc": "10000034", + "x-inner-ntwk": "2", + "x-m4c-caller": "PC", + "x-m4c-src": "10002", + "x-SvcType": svcType, + "X-Yun-Api-Version": "v1", + "X-Yun-App-Channel": "10000034", + "X-Yun-Channel-Source": "10000034", + "X-Yun-Client-Info": "||9|7.13.0|chrome|120.0.0.0|||windows 10||zh-CN|||dW5kZWZpbmVk||", + "X-Yun-Module-Type": "100", + "X-Yun-Svc-Type": "1", + }) + + var 
e BaseResp + req.SetResult(&e) + res, err := req.Execute(method, url) + if err != nil { + return nil, err + } + log.Debugln(res.String()) + if !e.Success { + return nil, errors.New(e.Message) + } + if resp != nil { + err = utils.Json.Unmarshal(res.Body(), resp) + if err != nil { + return nil, err + } + } + return res.Body(), nil +} +func (d *Yun139) personalPost(pathname string, data interface{}, resp interface{}) ([]byte, error) { + return d.personalRequest(pathname, http.MethodPost, func(req *resty.Request) { + req.SetBody(data) + }, resp) +} + +func getPersonalTime(t string) time.Time { + stamp, err := time.ParseInLocation("2006-01-02T15:04:05.999-07:00", t, utils.CNLoc) + if err != nil { + panic(err) + } + return stamp +} + +func (d *Yun139) personalGetFiles(fileId string) ([]model.Obj, error) { + files := make([]model.Obj, 0) + nextPageCursor := "" + for { + data := base.Json{ + "imageThumbnailStyleList": []string{"Small", "Large"}, + "orderBy": "updated_at", + "orderDirection": "DESC", + "pageInfo": base.Json{ + "pageCursor": nextPageCursor, + "pageSize": 100, + }, + "parentFileId": fileId, + } + var resp PersonalListResp + _, err := d.personalPost("/hcy/file/list", data, &resp) + if err != nil { + return nil, err + } + nextPageCursor = resp.Data.NextPageCursor + for _, item := range resp.Data.Items { + var isFolder = (item.Type == "folder") + var f model.Obj + if isFolder { + f = &model.Object{ + ID: item.FileId, + Name: item.Name, + Size: 0, + Modified: getPersonalTime(item.UpdatedAt), + Ctime: getPersonalTime(item.CreatedAt), + IsFolder: isFolder, + } + } else { + var Thumbnails = item.Thumbnails + var ThumbnailUrl string + if len(Thumbnails) > 0 { + ThumbnailUrl = Thumbnails[len(Thumbnails)-1].Url + } + f = &model.ObjThumb{ + Object: model.Object{ + ID: item.FileId, + Name: item.Name, + Size: item.Size, + Modified: getPersonalTime(item.UpdatedAt), + Ctime: getPersonalTime(item.CreatedAt), + IsFolder: isFolder, + }, + Thumbnail: model.Thumbnail{Thumbnail: ThumbnailUrl}, + } + } + files = append(files, f) + } + if len(nextPageCursor) == 0 { + break + } + } + return files, nil +} + +func (d *Yun139) personalGetLink(fileId string) (string, error) { + data := base.Json{ + "fileId": fileId, + } + res, err := d.personalPost("/hcy/file/getDownloadUrl", + data, nil) + if err != nil { + return "", err + } + var cdnUrl = jsoniter.Get(res, "data", "cdnUrl").ToString() + if cdnUrl != "" { + return cdnUrl, nil + } else { + return jsoniter.Get(res, "data", "url").ToString(), nil + } +} diff --git a/drivers/189/driver.go b/drivers/189/driver.go new file mode 100644 index 0000000000000000000000000000000000000000..6fc4932640c9d6bc201d8a0fa1c384a357ee854d --- /dev/null +++ b/drivers/189/driver.go @@ -0,0 +1,197 @@ +package _189 + +import ( + "context" + "net/http" + "strings" + + "github.com/alist-org/alist/v3/drivers/base" + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/go-resty/resty/v2" + log "github.com/sirupsen/logrus" +) + +type Cloud189 struct { + model.Storage + Addition + client *resty.Client + rsa Rsa + sessionKey string +} + +func (d *Cloud189) Config() driver.Config { + return config +} + +func (d *Cloud189) GetAddition() driver.Additional { + return &d.Addition +} + +func (d *Cloud189) Init(ctx context.Context) error { + d.client = base.NewRestyClient(). 
+ SetHeader("Referer", "https://cloud.189.cn/") + return d.newLogin() +} + +func (d *Cloud189) Drop(ctx context.Context) error { + return nil +} + +func (d *Cloud189) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) { + return d.getFiles(dir.GetID()) +} + +func (d *Cloud189) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) { + var resp DownResp + u := "https://cloud.189.cn/api/portal/getFileInfo.action" + _, err := d.request(u, http.MethodGet, func(req *resty.Request) { + req.SetQueryParam("fileId", file.GetID()) + }, &resp) + if err != nil { + return nil, err + } + client := resty.NewWithClient(d.client.GetClient()).SetRedirectPolicy( + resty.RedirectPolicyFunc(func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + })) + res, err := client.R().SetHeader("User-Agent", base.UserAgent).Get("https:" + resp.FileDownloadUrl) + if err != nil { + return nil, err + } + log.Debugln(res.Status()) + log.Debugln(res.String()) + link := model.Link{} + log.Debugln("first url:", resp.FileDownloadUrl) + if res.StatusCode() == 302 { + link.URL = res.Header().Get("location") + log.Debugln("second url:", link.URL) + // follow a possible second redirect and check the fresh response + res, err = client.R().SetHeader("User-Agent", base.UserAgent).Get(link.URL) + if err != nil { + return nil, err + } + if res.StatusCode() == 302 { + link.URL = res.Header().Get("location") + } + log.Debugln("third url:", link.URL) + } else { + link.URL = resp.FileDownloadUrl + } + link.URL = strings.Replace(link.URL, "http://", "https://", 1) + return &link, nil +} + +func (d *Cloud189) MakeDir(ctx context.Context, parentDir model.Obj, dirName string) error { + form := map[string]string{ + "parentFolderId": parentDir.GetID(), + "folderName": dirName, + } + _, err := d.request("https://cloud.189.cn/api/open/file/createFolder.action", http.MethodPost, func(req *resty.Request) { + req.SetFormData(form) + }, nil) + return err +} + +func (d *Cloud189) Move(ctx context.Context, srcObj, dstDir model.Obj) error { + isFolder := 0 + if srcObj.IsDir() { + isFolder = 1 + } + taskInfos := []base.Json{ + { + "fileId": srcObj.GetID(), + "fileName": srcObj.GetName(), + "isFolder": isFolder, + }, + } + taskInfosBytes, err := utils.Json.Marshal(taskInfos) + if err != nil { + return err + } + form := map[string]string{ + "type": "MOVE", + "targetFolderId": dstDir.GetID(), + "taskInfos": string(taskInfosBytes), + } + _, err = d.request("https://cloud.189.cn/api/open/batch/createBatchTask.action", http.MethodPost, func(req *resty.Request) { + req.SetFormData(form) + }, nil) + return err +} + +func (d *Cloud189) Rename(ctx context.Context, srcObj model.Obj, newName string) error { + url := "https://cloud.189.cn/api/open/file/renameFile.action" + idKey := "fileId" + nameKey := "destFileName" + if srcObj.IsDir() { + url = "https://cloud.189.cn/api/open/file/renameFolder.action" + idKey = "folderId" + nameKey = "destFolderName" + } + form := map[string]string{ + idKey: srcObj.GetID(), + nameKey: newName, + } + _, err := d.request(url, http.MethodPost, func(req *resty.Request) { + req.SetFormData(form) + }, nil) + return err +} + +func (d *Cloud189) Copy(ctx context.Context, srcObj, dstDir model.Obj) error { + isFolder := 0 + if srcObj.IsDir() { + isFolder = 1 + } + taskInfos := []base.Json{ + { + "fileId": srcObj.GetID(), + "fileName": srcObj.GetName(), + "isFolder": isFolder, + }, + } + taskInfosBytes, err := utils.Json.Marshal(taskInfos) + if err != nil { + return err + } + form := map[string]string{ + "type": "COPY", + "targetFolderId": dstDir.GetID(), + "taskInfos": string(taskInfosBytes), + } + _,
err = d.request("https://cloud.189.cn/api/open/batch/createBatchTask.action", http.MethodPost, func(req *resty.Request) { + req.SetFormData(form) + }, nil) + return err +} + +func (d *Cloud189) Remove(ctx context.Context, obj model.Obj) error { + isFolder := 0 + if obj.IsDir() { + isFolder = 1 + } + taskInfos := []base.Json{ + { + "fileId": obj.GetID(), + "fileName": obj.GetName(), + "isFolder": isFolder, + }, + } + taskInfosBytes, err := utils.Json.Marshal(taskInfos) + if err != nil { + return err + } + form := map[string]string{ + "type": "DELETE", + "targetFolderId": "", + "taskInfos": string(taskInfosBytes), + } + _, err = d.request("https://cloud.189.cn/api/open/batch/createBatchTask.action", http.MethodPost, func(req *resty.Request) { + req.SetFormData(form) + }, nil) + return err +} + +func (d *Cloud189) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error { + return d.newUpload(ctx, dstDir, stream, up) +} + +var _ driver.Driver = (*Cloud189)(nil) diff --git a/drivers/189/help.go b/drivers/189/help.go new file mode 100644 index 0000000000000000000000000000000000000000..a86108e55fca9ee183ed4f1ba850acf2f70f0861 --- /dev/null +++ b/drivers/189/help.go @@ -0,0 +1,186 @@ +package _189 + +import ( + "bytes" + "crypto/aes" + "crypto/hmac" + "crypto/md5" + "crypto/rand" + "crypto/rsa" + "crypto/sha1" + "crypto/x509" + "encoding/base64" + "encoding/hex" + "encoding/pem" + "fmt" + "net/url" + "regexp" + "strconv" + "strings" + + myrand "github.com/alist-org/alist/v3/pkg/utils/random" + log "github.com/sirupsen/logrus" +) + +func random() string { + return fmt.Sprintf("0.%17v", myrand.Rand.Int63n(100000000000000000)) +} + +func RsaEncode(origData []byte, j_rsakey string, hex bool) string { + publicKey := []byte("-----BEGIN PUBLIC KEY-----\n" + j_rsakey + "\n-----END PUBLIC KEY-----") + block, _ := pem.Decode(publicKey) + pubInterface, _ := x509.ParsePKIXPublicKey(block.Bytes) + pub := pubInterface.(*rsa.PublicKey) + b, err := rsa.EncryptPKCS1v15(rand.Reader, pub, origData) + if err != nil { + log.Errorf("err: %s", err.Error()) + } + res := base64.StdEncoding.EncodeToString(b) + if hex { + return b64tohex(res) + } + return res +} + +var b64map = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/" + +var BI_RM = "0123456789abcdefghijklmnopqrstuvwxyz" + +func int2char(a int) string { + return strings.Split(BI_RM, "")[a] +} + +func b64tohex(a string) string { + d := "" + e := 0 + c := 0 + for i := 0; i < len(a); i++ { + m := strings.Split(a, "")[i] + if m != "=" { + v := strings.Index(b64map, m) + if 0 == e { + e = 1 + d += int2char(v >> 2) + c = 3 & v + } else if 1 == e { + e = 2 + d += int2char(c<<2 | v>>4) + c = 15 & v + } else if 2 == e { + e = 3 + d += int2char(c) + d += int2char(v >> 2) + c = 3 & v + } else { + e = 0 + d += int2char(c<<2 | v>>4) + d += int2char(15 & v) + } + } + } + if e == 1 { + d += int2char(c << 2) + } + return d +} + +func qs(form map[string]string) string { + f := make(url.Values) + for k, v := range form { + f.Set(k, v) + } + return EncodeParam(f) + //strList := make([]string, 0) + //for k, v := range form { + // strList = append(strList, fmt.Sprintf("%s=%s", k, url.QueryEscape(v))) + //} + //return strings.Join(strList, "&") +} + +func EncodeParam(v url.Values) string { + if v == nil { + return "" + } + var buf strings.Builder + keys := make([]string, 0, len(v)) + for k := range v { + keys = append(keys, k) + } + for _, k := range keys { + vs := v[k] + for _, v := range vs { + if buf.Len() > 0 { + 
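// unlike url.Values.Encode, keys stay in map order and values are written unescaped: the result is AES-encrypted into the params blob (see uploadRequest) rather than sent as a literal query string +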
buf.WriteByte('&') + } + buf.WriteString(k) + buf.WriteByte('=') + //if k == "fileName" { + // buf.WriteString(encode(v)) + //} else { + buf.WriteString(v) + //} + } + } + return buf.String() +} + +func encode(str string) string { + //str = strings.ReplaceAll(str, "%", "%25") + //str = strings.ReplaceAll(str, "&", "%26") + //str = strings.ReplaceAll(str, "+", "%2B") + //return str + return url.QueryEscape(str) +} + +func AesEncrypt(data, key []byte) []byte { + block, _ := aes.NewCipher(key) + if block == nil { + return []byte{} + } + data = PKCS7Padding(data, block.BlockSize()) + decrypted := make([]byte, len(data)) + size := block.BlockSize() + for bs, be := 0, size; bs < len(data); bs, be = bs+size, be+size { + block.Encrypt(decrypted[bs:be], data[bs:be]) + } + return decrypted +} + +func PKCS7Padding(ciphertext []byte, blockSize int) []byte { + padding := blockSize - len(ciphertext)%blockSize + padtext := bytes.Repeat([]byte{byte(padding)}, padding) + return append(ciphertext, padtext...) +} + +func hmacSha1(data string, secret string) string { + h := hmac.New(sha1.New, []byte(secret)) + h.Write([]byte(data)) + return hex.EncodeToString(h.Sum(nil)) +} + +func getMd5(data []byte) []byte { + h := md5.New() + h.Write(data) + return h.Sum(nil) +} + +func decodeURIComponent(str string) string { + r, _ := url.PathUnescape(str) + //r = strings.ReplaceAll(r, " ", "+") + return r +} + +func Random(v string) string { + reg := regexp.MustCompilePOSIX("[xy]") + data := reg.ReplaceAllFunc([]byte(v), func(msg []byte) []byte { + var i int64 + t := int64(16 * myrand.Rand.Float32()) + if msg[0] == 120 { + i = t + } else { + i = 3&t | 8 + } + return []byte(strconv.FormatInt(i, 16)) + }) + return string(data) +} diff --git a/drivers/189/login.go b/drivers/189/login.go new file mode 100644 index 0000000000000000000000000000000000000000..0fcec19aefb60cf0dbe74a753db3f7de18c75110 --- /dev/null +++ b/drivers/189/login.go @@ -0,0 +1,126 @@ +package _189 + +import ( + "errors" + "strconv" + + "github.com/alist-org/alist/v3/pkg/utils" + log "github.com/sirupsen/logrus" +) + +type AppConf struct { + Data struct { + AccountType string `json:"accountType"` + AgreementCheck string `json:"agreementCheck"` + AppKey string `json:"appKey"` + ClientType int `json:"clientType"` + IsOauth2 bool `json:"isOauth2"` + LoginSort string `json:"loginSort"` + MailSuffix string `json:"mailSuffix"` + PageKey string `json:"pageKey"` + ParamId string `json:"paramId"` + RegReturnUrl string `json:"regReturnUrl"` + ReqId string `json:"reqId"` + ReturnUrl string `json:"returnUrl"` + ShowFeedback string `json:"showFeedback"` + ShowPwSaveName string `json:"showPwSaveName"` + ShowQrSaveName string `json:"showQrSaveName"` + ShowSmsSaveName string `json:"showSmsSaveName"` + Sso string `json:"sso"` + } `json:"data"` + Msg string `json:"msg"` + Result string `json:"result"` +} + +type EncryptConf struct { + Result int `json:"result"` + Data struct { + UpSmsOn string `json:"upSmsOn"` + Pre string `json:"pre"` + PreDomain string `json:"preDomain"` + PubKey string `json:"pubKey"` + } `json:"data"` +} + +func (d *Cloud189) newLogin() error { + url := "https://cloud.189.cn/api/portal/loginUrl.action?redirectURL=https%3A%2F%2Fcloud.189.cn%2Fmain.action" + res, err := d.client.R().Get(url) + if err != nil { + return err + } + // Is logged in + redirectURL := res.RawResponse.Request.URL + if redirectURL.String() == "https://cloud.189.cn/web/main" { + return nil + } + lt := redirectURL.Query().Get("lt") + reqId := redirectURL.Query().Get("reqId") + appId 
:= redirectURL.Query().Get("appId") + headers := map[string]string{ + "lt": lt, + "reqid": reqId, + "referer": redirectURL.String(), + "origin": "https://open.e.189.cn", + } + // get app Conf + var appConf AppConf + res, err = d.client.R().SetHeaders(headers).SetFormData(map[string]string{ + "version": "2.0", + "appKey": appId, + }).SetResult(&appConf).Post("https://open.e.189.cn/api/logbox/oauth2/appConf.do") + if err != nil { + return err + } + log.Debugf("189 AppConf resp body: %s", res.String()) + if appConf.Result != "0" { + return errors.New(appConf.Msg) + } + // get encrypt conf + var encryptConf EncryptConf + res, err = d.client.R().SetHeaders(headers).SetFormData(map[string]string{ + "appId": appId, + }).Post("https://open.e.189.cn/api/logbox/config/encryptConf.do") + if err != nil { + return err + } + err = utils.Json.Unmarshal(res.Body(), &encryptConf) + if err != nil { + return err + } + log.Debugf("189 EncryptConf resp body: %s\n%+v", res.String(), encryptConf) + if encryptConf.Result != 0 { + return errors.New("get EncryptConf error:" + res.String()) + } + // TODO: getUUID? needcaptcha + // login + loginData := map[string]string{ + "version": "v2.0", + "apToken": "", + "appKey": appId, + "accountType": appConf.Data.AccountType, + "userName": encryptConf.Data.Pre + RsaEncode([]byte(d.Username), encryptConf.Data.PubKey, true), + "epd": encryptConf.Data.Pre + RsaEncode([]byte(d.Password), encryptConf.Data.PubKey, true), + "captchaType": "", + "validateCode": "", + "smsValidateCode": "", + "captchaToken": "", + "returnUrl": appConf.Data.ReturnUrl, + "mailSuffix": appConf.Data.MailSuffix, + "dynamicCheck": "FALSE", + "clientType": strconv.Itoa(appConf.Data.ClientType), + "cb_SaveName": "3", + "isOauth2": strconv.FormatBool(appConf.Data.IsOauth2), + "state": "", + "paramId": appConf.Data.ParamId, + } + res, err = d.client.R().SetHeaders(headers).SetFormData(loginData).Post("https://open.e.189.cn/api/logbox/oauth2/loginSubmit.do") + if err != nil { + return err + } + log.Debugf("189 login resp body: %s", res.String()) + loginResult := utils.Json.Get(res.Body(), "result").ToInt() + if loginResult != 0 { + return errors.New(utils.Json.Get(res.Body(), "msg").ToString()) + } + return nil +} diff --git a/drivers/189/meta.go b/drivers/189/meta.go new file mode 100644 index 0000000000000000000000000000000000000000..ad621fb440d25b733c6d3ab857a083d37a0e8251 --- /dev/null +++ b/drivers/189/meta.go @@ -0,0 +1,26 @@ +package _189 + +import ( + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/op" +) + +type Addition struct { + Username string `json:"username" required:"true"` + Password string `json:"password" required:"true"` + Cookie string `json:"cookie" help:"Fill in the cookie if need captcha"` + driver.RootID +} + +var config = driver.Config{ + Name: "189Cloud", + LocalSort: true, + DefaultRoot: "-11", + Alert: `info|You can try to use 189PC driver if this driver does not work.`, +} + +func init() { + op.RegisterDriver(func() driver.Driver { + return &Cloud189{} + }) +} diff --git a/drivers/189/types.go b/drivers/189/types.go new file mode 100644 index 0000000000000000000000000000000000000000..5354db953b59f7e6fb2d47b6363c61b810712b44 --- /dev/null +++ b/drivers/189/types.go @@ -0,0 +1,68 @@ +package _189 + +type LoginResp struct { + Msg string `json:"msg"` + Result int `json:"result"` + ToUrl string `json:"toUrl"` +} + +type Error struct { + ErrorCode string `json:"errorCode"` + ErrorMsg string `json:"errorMsg"` +} + +type File struct { + Id int64 
`json:"id"` + LastOpTime string `json:"lastOpTime"` + Name string `json:"name"` + Size int64 `json:"size"` + Icon struct { + SmallUrl string `json:"smallUrl"` + //LargeUrl string `json:"largeUrl"` + } `json:"icon"` + Url string `json:"url"` +} + +type Folder struct { + Id int64 `json:"id"` + LastOpTime string `json:"lastOpTime"` + Name string `json:"name"` +} + +type Files struct { + ResCode int `json:"res_code"` + ResMessage string `json:"res_message"` + FileListAO struct { + Count int `json:"count"` + FileList []File `json:"fileList"` + FolderList []Folder `json:"folderList"` + } `json:"fileListAO"` +} + +type UploadUrlsResp struct { + Code string `json:"code"` + UploadUrls map[string]Part `json:"uploadUrls"` +} + +type Part struct { + RequestURL string `json:"requestURL"` + RequestHeader string `json:"requestHeader"` +} + +type Rsa struct { + Expire int64 `json:"expire"` + PkId string `json:"pkId"` + PubKey string `json:"pubKey"` +} + +type Down struct { + ResCode int `json:"res_code"` + ResMessage string `json:"res_message"` + FileDownloadUrl string `json:"fileDownloadUrl"` +} + +type DownResp struct { + ResCode int `json:"res_code"` + ResMessage string `json:"res_message"` + FileDownloadUrl string `json:"downloadUrl"` +} diff --git a/drivers/189/util.go b/drivers/189/util.go new file mode 100644 index 0000000000000000000000000000000000000000..0b4c0633d7b14ccff68572736dcb03ffa6607edd --- /dev/null +++ b/drivers/189/util.go @@ -0,0 +1,398 @@ +package _189 + +import ( + "bytes" + "context" + "crypto/md5" + "encoding/base64" + "encoding/hex" + "errors" + "fmt" + "io" + "math" + "net/http" + "strconv" + "strings" + "time" + + "github.com/alist-org/alist/v3/drivers/base" + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/pkg/utils" + myrand "github.com/alist-org/alist/v3/pkg/utils/random" + "github.com/go-resty/resty/v2" + jsoniter "github.com/json-iterator/go" + log "github.com/sirupsen/logrus" +) + +// do others that not defined in Driver interface + +//func (d *Cloud189) login() error { +// url := "https://cloud.189.cn/api/portal/loginUrl.action?redirectURL=https%3A%2F%2Fcloud.189.cn%2Fmain.action" +// b := "" +// lt := "" +// ltText := regexp.MustCompile(`lt = "(.+?)"`) +// var res *resty.Response +// var err error +// for i := 0; i < 3; i++ { +// res, err = d.client.R().Get(url) +// if err != nil { +// return err +// } +// // 已经登陆 +// if res.RawResponse.Request.URL.String() == "https://cloud.189.cn/web/main" { +// return nil +// } +// b = res.String() +// ltTextArr := ltText.FindStringSubmatch(b) +// if len(ltTextArr) > 0 { +// lt = ltTextArr[1] +// break +// } else { +// <-time.After(time.Second) +// } +// } +// if lt == "" { +// return fmt.Errorf("get page: %s \nstatus: %d \nrequest url: %s\nredirect url: %s", +// b, res.StatusCode(), res.RawResponse.Request.URL.String(), res.Header().Get("location")) +// } +// captchaToken := regexp.MustCompile(`captchaToken' value='(.+?)'`).FindStringSubmatch(b)[1] +// returnUrl := regexp.MustCompile(`returnUrl = '(.+?)'`).FindStringSubmatch(b)[1] +// paramId := regexp.MustCompile(`paramId = "(.+?)"`).FindStringSubmatch(b)[1] +// //reqId := regexp.MustCompile(`reqId = "(.+?)"`).FindStringSubmatch(b)[1] +// jRsakey := regexp.MustCompile(`j_rsaKey" value="(\S+)"`).FindStringSubmatch(b)[1] +// vCodeID := regexp.MustCompile(`picCaptcha\.do\?token\=([A-Za-z0-9\&\=]+)`).FindStringSubmatch(b)[1] +// vCodeRS := "" +// if vCodeID != "" { +// // need ValidateCode +// 
log.Debugf("try to identify verification codes") +// timeStamp := strconv.FormatInt(time.Now().UnixNano()/1e6, 10) +// u := "https://open.e.189.cn/api/logbox/oauth2/picCaptcha.do?token=" + vCodeID + timeStamp +// imgRes, err := d.client.R().SetHeaders(map[string]string{ +// "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:74.0) Gecko/20100101 Firefox/76.0", +// "Referer": "https://open.e.189.cn/api/logbox/oauth2/unifyAccountLogin.do", +// "Sec-Fetch-Dest": "image", +// "Sec-Fetch-Mode": "no-cors", +// "Sec-Fetch-Site": "same-origin", +// }).Get(u) +// if err != nil { +// return err +// } +// // Enter the verification code manually +// //err = message.GetMessenger().WaitSend(message.Message{ +// // Type: "image", +// // Content: "data:image/png;base64," + base64.StdEncoding.EncodeToString(imgRes.Body()), +// //}, 10) +// //if err != nil { +// // return err +// //} +// //vCodeRS, err = message.GetMessenger().WaitReceive(30) +// // use ocr api +// vRes, err := base.RestyClient.R().SetMultipartField( +// "image", "validateCode.png", "image/png", bytes.NewReader(imgRes.Body())). +// Post(setting.GetStr(conf.OcrApi)) +// if err != nil { +// return err +// } +// if jsoniter.Get(vRes.Body(), "status").ToInt() != 200 { +// return errors.New("ocr error:" + jsoniter.Get(vRes.Body(), "msg").ToString()) +// } +// vCodeRS = jsoniter.Get(vRes.Body(), "result").ToString() +// log.Debugln("code: ", vCodeRS) +// } +// userRsa := RsaEncode([]byte(d.Username), jRsakey, true) +// passwordRsa := RsaEncode([]byte(d.Password), jRsakey, true) +// url = "https://open.e.189.cn/api/logbox/oauth2/loginSubmit.do" +// var loginResp LoginResp +// res, err = d.client.R(). +// SetHeaders(map[string]string{ +// "lt": lt, +// "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/87.0.4280.88 Safari/537.36", +// "Referer": "https://open.e.189.cn/", +// "accept": "application/json;charset=UTF-8", +// }).SetFormData(map[string]string{ +// "appKey": "cloud", +// "accountType": "01", +// "userName": "{RSA}" + userRsa, +// "password": "{RSA}" + passwordRsa, +// "validateCode": vCodeRS, +// "captchaToken": captchaToken, +// "returnUrl": returnUrl, +// "mailSuffix": "@pan.cn", +// "paramId": paramId, +// "clientType": "10010", +// "dynamicCheck": "FALSE", +// "cb_SaveName": "1", +// "isOauth2": "false", +// }).Post(url) +// if err != nil { +// return err +// } +// err = utils.Json.Unmarshal(res.Body(), &loginResp) +// if err != nil { +// log.Error(err.Error()) +// return err +// } +// if loginResp.Result != 0 { +// return fmt.Errorf(loginResp.Msg) +// } +// _, err = d.client.R().Get(loginResp.ToUrl) +// return err +//} + +func (d *Cloud189) request(url string, method string, callback base.ReqCallback, resp interface{}) ([]byte, error) { + var e Error + req := d.client.R().SetError(&e). + SetHeader("Accept", "application/json;charset=UTF-8"). 
+ SetQueryParams(map[string]string{ + "noCache": random(), + }) + if callback != nil { + callback(req) + } + if resp != nil { + req.SetResult(resp) + } + res, err := req.Execute(method, url) + if err != nil { + return nil, err + } + //log.Debug(res.String()) + if e.ErrorCode != "" { + if e.ErrorCode == "InvalidSessionKey" { + err = d.newLogin() + if err != nil { + return nil, err + } + return d.request(url, method, callback, resp) + } + } + if jsoniter.Get(res.Body(), "res_code").ToInt() != 0 { + err = errors.New(jsoniter.Get(res.Body(), "res_message").ToString()) + } + return res.Body(), err +} + +func (d *Cloud189) getFiles(fileId string) ([]model.Obj, error) { + res := make([]model.Obj, 0) + pageNum := 1 + for { + var resp Files + _, err := d.request("https://cloud.189.cn/api/open/file/listFiles.action", http.MethodGet, func(req *resty.Request) { + req.SetQueryParams(map[string]string{ + //"noCache": random(), + "pageSize": "60", + "pageNum": strconv.Itoa(pageNum), + "mediaType": "0", + "folderId": fileId, + "iconOption": "5", + "orderBy": "lastOpTime", //account.OrderBy + "descending": "true", //account.OrderDirection + }) + }, &resp) + if err != nil { + return nil, err + } + if resp.FileListAO.Count == 0 { + break + } + for _, folder := range resp.FileListAO.FolderList { + lastOpTime := utils.MustParseCNTime(folder.LastOpTime) + res = append(res, &model.Object{ + ID: strconv.FormatInt(folder.Id, 10), + Name: folder.Name, + Modified: lastOpTime, + IsFolder: true, + }) + } + for _, file := range resp.FileListAO.FileList { + lastOpTime := utils.MustParseCNTime(file.LastOpTime) + res = append(res, &model.ObjThumb{ + Object: model.Object{ + ID: strconv.FormatInt(file.Id, 10), + Name: file.Name, + Modified: lastOpTime, + Size: file.Size, + }, + Thumbnail: model.Thumbnail{Thumbnail: file.Icon.SmallUrl}, + }) + } + pageNum++ + } + return res, nil +} + +func (d *Cloud189) oldUpload(dstDir model.Obj, file model.FileStreamer) error { + res, err := d.client.R().SetMultipartFormData(map[string]string{ + "parentId": dstDir.GetID(), + "sessionKey": "??", + "opertype": "1", + "fname": file.GetName(), + }).SetMultipartField("Filedata", file.GetName(), file.GetMimetype(), file).Post("https://hb02.upload.cloud.189.cn/v1/DCIWebUploadAction") + if err != nil { + return err + } + if utils.Json.Get(res.Body(), "MD5").ToString() != "" { + return nil + } + log.Debugf(res.String()) + return errors.New(res.String()) +} + +func (d *Cloud189) getSessionKey() (string, error) { + resp, err := d.request("https://cloud.189.cn/v2/getUserBriefInfo.action", http.MethodGet, nil, nil) + if err != nil { + return "", err + } + sessionKey := utils.Json.Get(resp, "sessionKey").ToString() + return sessionKey, nil +} + +func (d *Cloud189) getResKey() (string, string, error) { + now := time.Now().UnixMilli() + if d.rsa.Expire > now { + return d.rsa.PubKey, d.rsa.PkId, nil + } + resp, err := d.request("https://cloud.189.cn/api/security/generateRsaKey.action", http.MethodGet, nil, nil) + if err != nil { + return "", "", err + } + pubKey, pkId := utils.Json.Get(resp, "pubKey").ToString(), utils.Json.Get(resp, "pkId").ToString() + d.rsa.PubKey, d.rsa.PkId = pubKey, pkId + d.rsa.Expire = utils.Json.Get(resp, "expire").ToInt64() + return pubKey, pkId, nil +} + +func (d *Cloud189) uploadRequest(uri string, form map[string]string, resp interface{}) ([]byte, error) { + c := strconv.FormatInt(time.Now().UnixMilli(), 10) + r := Random("xxxxxxxx-xxxx-4xxx-yxxx-xxxxxxxxxxxx") + l := Random("xxxxxxxxxxxx4xxxyxxxxxxxxxxxxxxx") + l = l[0 : 
16+int(16*myrand.Rand.Float32())] + + e := qs(form) + data := AesEncrypt([]byte(e), []byte(l[0:16])) + h := hex.EncodeToString(data) + + sessionKey := d.sessionKey + signature := hmacSha1(fmt.Sprintf("SessionKey=%s&Operate=GET&RequestURI=%s&Date=%s&params=%s", sessionKey, uri, c, h), l) + + pubKey, pkId, err := d.getResKey() + if err != nil { + return nil, err + } + b := RsaEncode([]byte(l), pubKey, false) + req := d.client.R().SetHeaders(map[string]string{ + "accept": "application/json;charset=UTF-8", + "SessionKey": sessionKey, + "Signature": signature, + "X-Request-Date": c, + "X-Request-ID": r, + "EncryptionText": b, + "PkId": pkId, + }) + if resp != nil { + req.SetResult(resp) + } + res, err := req.Get("https://upload.cloud.189.cn" + uri + "?params=" + h) + if err != nil { + return nil, err + } + data = res.Body() + if utils.Json.Get(data, "code").ToString() != "SUCCESS" { + return nil, errors.New(uri + "---" + jsoniter.Get(data, "msg").ToString()) + } + return data, nil +} + +func (d *Cloud189) newUpload(ctx context.Context, dstDir model.Obj, file model.FileStreamer, up driver.UpdateProgress) error { + sessionKey, err := d.getSessionKey() + if err != nil { + return err + } + d.sessionKey = sessionKey + const DEFAULT int64 = 10485760 + var count = int64(math.Ceil(float64(file.GetSize()) / float64(DEFAULT))) + + res, err := d.uploadRequest("/person/initMultiUpload", map[string]string{ + "parentFolderId": dstDir.GetID(), + "fileName": encode(file.GetName()), + "fileSize": strconv.FormatInt(file.GetSize(), 10), + "sliceSize": strconv.FormatInt(DEFAULT, 10), + "lazyCheck": "1", + }, nil) + if err != nil { + return err + } + uploadFileId := jsoniter.Get(res, "data", "uploadFileId").ToString() + //_, err = d.uploadRequest("/person/getUploadedPartsInfo", map[string]string{ + // "uploadFileId": uploadFileId, + //}, nil) + var finish int64 = 0 + var i int64 + var byteSize int64 + md5s := make([]string, 0) + md5Sum := md5.New() + for i = 1; i <= count; i++ { + if utils.IsCanceled(ctx) { + return ctx.Err() + } + byteSize = file.GetSize() - finish + if DEFAULT < byteSize { + byteSize = DEFAULT + } + //log.Debugf("%d,%d", byteSize, finish) + byteData := make([]byte, byteSize) + n, err := io.ReadFull(file, byteData) + //log.Debug(err, n) + if err != nil { + return err + } + finish += int64(n) + md5Bytes := getMd5(byteData) + md5Hex := hex.EncodeToString(md5Bytes) + md5Base64 := base64.StdEncoding.EncodeToString(md5Bytes) + md5s = append(md5s, strings.ToUpper(md5Hex)) + md5Sum.Write(byteData) + var resp UploadUrlsResp + res, err = d.uploadRequest("/person/getMultiUploadUrls", map[string]string{ + "partInfo": fmt.Sprintf("%s-%s", strconv.FormatInt(i, 10), md5Base64), + "uploadFileId": uploadFileId, + }, &resp) + if err != nil { + return err + } + uploadData := resp.UploadUrls["partNumber_"+strconv.FormatInt(i, 10)] + log.Debugf("uploadData: %+v", uploadData) + requestURL := uploadData.RequestURL + uploadHeaders := strings.Split(decodeURIComponent(uploadData.RequestHeader), "&") + req, err := http.NewRequest(http.MethodPut, requestURL, bytes.NewReader(byteData)) + if err != nil { + return err + } + req = req.WithContext(ctx) + for _, v := range uploadHeaders { + i := strings.Index(v, "=") + req.Header.Set(v[0:i], v[i+1:]) + } + r, err := base.HttpClient.Do(req) + if err != nil { + return err + } + log.Debugf("%+v %+v", r, r.Request.Header) + _ = r.Body.Close() + up(float64(i) * 100 / float64(count)) + } + fileMd5 := hex.EncodeToString(md5Sum.Sum(nil)) + sliceMd5 := fileMd5 + if file.GetSize() > DEFAULT { +
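// the rapid-upload digest is two-level for multi-part files: fileMd5 covers the whole content, while sliceMd5 below is the MD5 of the upper-cased part MD5s joined by newlines +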
sliceMd5 = utils.GetMD5EncodeStr(strings.Join(md5s, "\n")) + } + res, err = d.uploadRequest("/person/commitMultiUploadFile", map[string]string{ + "uploadFileId": uploadFileId, + "fileMd5": fileMd5, + "sliceMd5": sliceMd5, + "lazyCheck": "1", + "opertype": "3", + }, nil) + return err +} diff --git a/drivers/189pc/driver.go b/drivers/189pc/driver.go new file mode 100644 index 0000000000000000000000000000000000000000..9c01a50fd86f6ade48ef9a439dfd1c85185cb697 --- /dev/null +++ b/drivers/189pc/driver.go @@ -0,0 +1,361 @@ +package _189pc + +import ( + "container/ring" + "context" + "net/http" + "strconv" + "strings" + "time" + + "github.com/alist-org/alist/v3/drivers/base" + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/errs" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/go-resty/resty/v2" +) + +type Cloud189PC struct { + model.Storage + Addition + + identity string + + client *resty.Client + + loginParam *LoginParam + tokenInfo *AppSessionResp + + uploadThread int + + familyTransferFolder *ring.Ring + cleanFamilyTransferFile func() + + storageConfig driver.Config +} + +func (y *Cloud189PC) Config() driver.Config { + if y.storageConfig.Name == "" { + y.storageConfig = config + } + return y.storageConfig +} + +func (y *Cloud189PC) GetAddition() driver.Additional { + return &y.Addition +} + +func (y *Cloud189PC) Init(ctx context.Context) (err error) { + // stay compatible with the old upload API + y.storageConfig.NoOverwriteUpload = y.isFamily() && (y.Addition.RapidUpload || y.Addition.UploadMethod == "old") + + // normalize the root folder for personal vs family cloud + if y.isFamily() && y.RootFolderID == "-11" { + y.RootFolderID = "" + } + if !y.isFamily() && y.RootFolderID == "" { + y.RootFolderID = "-11" + } + + // clamp the number of upload threads + y.uploadThread, _ = strconv.Atoi(y.UploadThread) + if y.uploadThread < 1 || y.uploadThread > 32 { + y.uploadThread, y.UploadThread = 3, "3" + } + + // initialize the request client + if y.client == nil { + y.client = base.NewRestyClient().SetHeaders(map[string]string{ + "Accept": "application/json;charset=UTF-8", + "Referer": WEB_URL, + }) + } + + // avoid logging in again when the credentials are unchanged + identity := utils.GetMD5EncodeStr(y.Username + y.Password) + if !y.isLogin() || y.identity != identity { + y.identity = identity + if err = y.login(); err != nil { + return + } + } + + // resolve the family cloud ID + if y.FamilyID == "" { + if y.FamilyID, err = y.getFamilyID(); err != nil { + return err + } + } + + // create transfer folders to avoid duplicate file names + if y.FamilyTransfer { + if y.familyTransferFolder, err = y.createFamilyTransferFolder(32); err != nil { + return err + } + } + + y.cleanFamilyTransferFile = utils.NewThrottle2(time.Minute, func() { + if err := y.cleanFamilyTransfer(context.TODO()); err != nil { + utils.Log.Errorf("cleanFamilyTransferFolderError:%s", err) + } + }) + return +} + +func (y *Cloud189PC) Drop(ctx context.Context) error { + return nil +} + +func (y *Cloud189PC) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) { + return y.getFiles(ctx, dir.GetID(), y.isFamily()) +} + +func (y *Cloud189PC) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) { + var downloadUrl struct { + URL string `json:"fileDownloadUrl"` + } + + isFamily := y.isFamily() + fullUrl := API_URL + if isFamily { + fullUrl += "/family/file" + } + fullUrl += "/getFileDownloadUrl.action" + + _, err := y.get(fullUrl, func(r *resty.Request) { + r.SetContext(ctx) + r.SetQueryParam("fileId", file.GetID()) + if isFamily { + r.SetQueryParams(map[string]string{ + "familyId": y.FamilyID, + }) + } else {
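+ // the personal-cloud endpoint expects the dt/flag query parameters, while the family branch above only needs familyId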
+            r.SetQueryParams(map[string]string{
+                "dt":   "3",
+                "flag": "1",
+            })
+        }
+    }, &downloadUrl, isFamily)
+    if err != nil {
+        return nil, err
+    }
+
+    // Follow the redirect to obtain the real link
+    downloadUrl.URL = strings.Replace(strings.ReplaceAll(downloadUrl.URL, "&amp;", "&"), "http://", "https://", 1)
+    res, err := base.NoRedirectClient.R().SetContext(ctx).SetDoNotParseResponse(true).Get(downloadUrl.URL)
+    if err != nil {
+        return nil, err
+    }
+    defer res.RawBody().Close()
+    if res.StatusCode() == 302 {
+        downloadUrl.URL = res.Header().Get("location")
+    }
+
+    like := &model.Link{
+        URL: downloadUrl.URL,
+        Header: http.Header{
+            "User-Agent": []string{base.UserAgent},
+        },
+    }
+    /*
+        // Extract the link expiration time from the URL
+        strs := regexp.MustCompile(`(?i)expire[^=]*=([0-9]*)`).FindStringSubmatch(downloadUrl.URL)
+        if len(strs) == 2 {
+            timestamp, err := strconv.ParseInt(strs[1], 10, 64)
+            if err == nil {
+                expired := time.Duration(timestamp-time.Now().Unix()) * time.Second
+                like.Expiration = &expired
+            }
+        }
+    */
+    return like, nil
+}
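The `Link` method above normalizes the returned URL and then resolves a single 302 hop through resty's `NoRedirectClient`. For readers unfamiliar with that client, here is a minimal stdlib-only sketch of the same one-hop resolution; the URL is a placeholder, not a real endpoint:

```go
package main

import (
	"fmt"
	"net/http"
)

func main() {
	client := &http.Client{
		// Return the redirect response itself instead of following it,
		// so the Location header can be read manually.
		CheckRedirect: func(req *http.Request, via []*http.Request) error {
			return http.ErrUseLastResponse
		},
	}
	url := "https://example.com/download" // placeholder
	resp, err := client.Get(url)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	if resp.StatusCode == http.StatusFound {
		url = resp.Header.Get("Location")
	}
	fmt.Println("final URL:", url)
}
```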
y.FamilyID, ""), dstDir.GetID(), other, BatchTaskInfo{ + FileId: srcObj.GetID(), + FileName: srcObj.GetName(), + IsFolder: BoolToNumber(srcObj.IsDir()), + }) + + if err != nil { + return err + } + return y.WaitBatchTask("COPY", resp.TaskID, time.Second) +} + +func (y *Cloud189PC) Remove(ctx context.Context, obj model.Obj) error { + isFamily := y.isFamily() + + resp, err := y.CreateBatchTask("DELETE", IF(isFamily, y.FamilyID, ""), "", nil, BatchTaskInfo{ + FileId: obj.GetID(), + FileName: obj.GetName(), + IsFolder: BoolToNumber(obj.IsDir()), + }) + if err != nil { + return err + } + // 批量任务数量限制,过快会导致无法删除 + return y.WaitBatchTask("DELETE", resp.TaskID, time.Millisecond*200) +} + +func (y *Cloud189PC) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) (newObj model.Obj, err error) { + overwrite := true + isFamily := y.isFamily() + + // 响应时间长,按需启用 + if y.Addition.RapidUpload && !stream.IsForceStreamUpload() { + if newObj, err := y.RapidUpload(ctx, dstDir, stream, isFamily, overwrite); err == nil { + return newObj, nil + } + } + + uploadMethod := y.UploadMethod + if stream.IsForceStreamUpload() { + uploadMethod = "stream" + } + + // 旧版上传家庭云也有限制 + if uploadMethod == "old" { + return y.OldUpload(ctx, dstDir, stream, up, isFamily, overwrite) + } + + // 开启家庭云转存 + if !isFamily && y.FamilyTransfer { + // 修改上传目标为家庭云文件夹 + transferDstDir := dstDir + dstDir = (y.familyTransferFolder.Value).(*Cloud189Folder) + y.familyTransferFolder = y.familyTransferFolder.Next() + + isFamily = true + overwrite = false + + defer func() { + if newObj != nil { + // 批量任务有概率删不掉 + y.cleanFamilyTransferFile() + + // 转存家庭云文件到个人云 + err = y.SaveFamilyFileToPersonCloud(context.TODO(), y.FamilyID, newObj, transferDstDir, true) + + task := BatchTaskInfo{ + FileId: newObj.GetID(), + FileName: newObj.GetName(), + IsFolder: BoolToNumber(newObj.IsDir()), + } + + // 删除源文件 + if resp, err := y.CreateBatchTask("DELETE", y.FamilyID, "", nil, task); err == nil { + y.WaitBatchTask("DELETE", resp.TaskID, time.Second) + // 永久删除 + if resp, err := y.CreateBatchTask("CLEAR_RECYCLE", y.FamilyID, "", nil, task); err == nil { + y.WaitBatchTask("CLEAR_RECYCLE", resp.TaskID, time.Second) + } + } + newObj = nil + } + }() + } + + switch uploadMethod { + case "rapid": + return y.FastUpload(ctx, dstDir, stream, up, isFamily, overwrite) + case "stream": + if stream.GetSize() == 0 { + return y.FastUpload(ctx, dstDir, stream, up, isFamily, overwrite) + } + fallthrough + default: + return y.StreamUpload(ctx, dstDir, stream, up, isFamily, overwrite) + } +} diff --git a/drivers/189pc/help.go b/drivers/189pc/help.go new file mode 100644 index 0000000000000000000000000000000000000000..49f957fab1dd3dce72b9ea7409191cdd02aef0f7 --- /dev/null +++ b/drivers/189pc/help.go @@ -0,0 +1,210 @@ +package _189pc + +import ( + "bytes" + "crypto/aes" + "crypto/hmac" + "crypto/rand" + "crypto/rsa" + "crypto/sha1" + "crypto/x509" + "encoding/hex" + "encoding/pem" + "encoding/xml" + "fmt" + "math" + "net/http" + "regexp" + "strings" + "time" + + "github.com/alist-org/alist/v3/pkg/utils/random" +) + +func clientSuffix() map[string]string { + rand := random.Rand + return map[string]string{ + "clientType": PC, + "version": VERSION, + "channelId": CHANNEL_ID, + "rand": fmt.Sprintf("%d_%d", rand.Int63n(1e5), rand.Int63n(1e10)), + } +} + +// 带params的SignatureOfHmac HMAC签名 +func signatureOfHmac(sessionSecret, sessionKey, operate, fullUrl, dateOfGmt, param string) string { + urlpath := 
diff --git a/drivers/189pc/help.go b/drivers/189pc/help.go
new file mode 100644
index 0000000000000000000000000000000000000000..49f957fab1dd3dce72b9ea7409191cdd02aef0f7
--- /dev/null
+++ b/drivers/189pc/help.go
@@ -0,0 +1,210 @@
+package _189pc
+
+import (
+    "bytes"
+    "crypto/aes"
+    "crypto/hmac"
+    "crypto/rand"
+    "crypto/rsa"
+    "crypto/sha1"
+    "crypto/x509"
+    "encoding/hex"
+    "encoding/pem"
+    "encoding/xml"
+    "fmt"
+    "math"
+    "net/http"
+    "regexp"
+    "strings"
+    "time"
+
+    "github.com/alist-org/alist/v3/pkg/utils/random"
+)
+
+func clientSuffix() map[string]string {
+    rand := random.Rand
+    return map[string]string{
+        "clientType": PC,
+        "version":    VERSION,
+        "channelId":  CHANNEL_ID,
+        "rand":       fmt.Sprintf("%d_%d", rand.Int63n(1e5), rand.Int63n(1e10)),
+    }
+}
+
+// HMAC signature (SignatureOfHmac); includes params when present
+func signatureOfHmac(sessionSecret, sessionKey, operate, fullUrl, dateOfGmt, param string) string {
+    urlpath := regexp.MustCompile(`://[^/]+((/[^/\s?#]+)*)`).FindStringSubmatch(fullUrl)[1]
+    mac := hmac.New(sha1.New, []byte(sessionSecret))
+    data := fmt.Sprintf("SessionKey=%s&Operate=%s&RequestURI=%s&Date=%s", sessionKey, operate, urlpath, dateOfGmt)
+    if param != "" {
+        data += fmt.Sprintf("&params=%s", param)
+    }
+    mac.Write([]byte(data))
+    return strings.ToUpper(hex.EncodeToString(mac.Sum(nil)))
+}
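A self-contained demonstration of the signing scheme implemented by `signatureOfHmac`: the string that gets signed concatenates the session key, HTTP method, request path, and GMT date. The session values below are made up; real ones come from the login/session response:

```go
package main

import (
	"crypto/hmac"
	"crypto/sha1"
	"encoding/hex"
	"fmt"
	"strings"
)

func main() {
	sessionKey, sessionSecret := "demo-session-key", "demo-session-secret"
	date := "Mon, 02 Jan 2006 15:04:05 GMT"
	// The exact layout signatureOfHmac signs.
	data := fmt.Sprintf("SessionKey=%s&Operate=%s&RequestURI=%s&Date=%s",
		sessionKey, "GET", "/listFiles.action", date)
	mac := hmac.New(sha1.New, []byte(sessionSecret))
	mac.Write([]byte(data))
	fmt.Println(strings.ToUpper(hex.EncodeToString(mac.Sum(nil))))
}
```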
+
+// Encrypt the username and password with RSA
+func RsaEncrypt(publicKey, origData string) string {
+    block, _ := pem.Decode([]byte(publicKey))
+    pubInterface, _ := x509.ParsePKIXPublicKey(block.Bytes)
+    data, _ := rsa.EncryptPKCS1v15(rand.Reader, pubInterface.(*rsa.PublicKey), []byte(origData))
+    return strings.ToUpper(hex.EncodeToString(data))
+}
+
+// Encrypt params with AES-ECB
+func AesECBEncrypt(data, key string) string {
+    block, _ := aes.NewCipher([]byte(key))
+    paddingData := PKCS7Padding([]byte(data), block.BlockSize())
+    decrypted := make([]byte, len(paddingData))
+    size := block.BlockSize()
+    for src, dst := paddingData, decrypted; len(src) > 0; src, dst = src[size:], dst[size:] {
+        block.Encrypt(dst[:size], src[:size])
+    }
+    return strings.ToUpper(hex.EncodeToString(decrypted))
+}
+
+func PKCS7Padding(ciphertext []byte, blockSize int) []byte {
+    padding := blockSize - len(ciphertext)%blockSize
+    padtext := bytes.Repeat([]byte{byte(padding)}, padding)
+    return append(ciphertext, padtext...)
+}
+
+// HTTP-spec formatted date
+func getHttpDateStr() string {
+    return time.Now().UTC().Format(http.TimeFormat)
+}
+
+// Millisecond timestamp
+func timestamp() int64 {
+    return time.Now().UTC().UnixNano() / 1e6
+}
+
+func MustParseTime(str string) *time.Time {
+    lastOpTime, _ := time.ParseInLocation("2006-01-02 15:04:05 -07", str+" +08", time.Local)
+    return &lastOpTime
+}
+
+type Time time.Time
+
+func (t *Time) UnmarshalJSON(b []byte) error { return t.Unmarshal(b) }
+func (t *Time) UnmarshalXML(e *xml.Decoder, ee xml.StartElement) error {
+    b, err := e.Token()
+    if err != nil {
+        return err
+    }
+    if b, ok := b.(xml.CharData); ok {
+        if err = t.Unmarshal(b); err != nil {
+            return err
+        }
+    }
+    return e.Skip()
+}
+func (t *Time) Unmarshal(b []byte) error {
+    bs := strings.Trim(string(b), "\"")
+    var v time.Time
+    var err error
+    for _, f := range []string{"2006-01-02 15:04:05 -07", "Jan 2, 2006 15:04:05 PM -07"} {
+        v, err = time.ParseInLocation(f, bs+" +08", time.Local)
+        if err == nil {
+            break
+        }
+    }
+    *t = Time(v)
+    return err
+}
+
+type String string
+
+func (t *String) UnmarshalJSON(b []byte) error { return t.Unmarshal(b) }
+func (t *String) UnmarshalXML(e *xml.Decoder, ee xml.StartElement) error {
+    b, err := e.Token()
+    if err != nil {
+        return err
+    }
+    if b, ok := b.(xml.CharData); ok {
+        if err = t.Unmarshal(b); err != nil {
+            return err
+        }
+    }
+    return e.Skip()
+}
+func (s *String) Unmarshal(b []byte) error {
+    *s = String(bytes.Trim(b, "\""))
+    return nil
+}
+
+func toFamilyOrderBy(o string) string {
+    switch o {
+    case "filename":
+        return "1"
+    case "filesize":
+        return "2"
+    case "lastOpTime":
+        return "3"
+    default:
+        return "1"
+    }
+}
+
+func toDesc(o string) string {
+    switch o {
+    case "desc":
+        return "true"
+    case "asc":
+        fallthrough
+    default:
+        return "false"
+    }
+}
+
+func ParseHttpHeader(str string) map[string]string {
+    header := make(map[string]string)
+    for _, value := range strings.Split(str, "&") {
+        if k, v, found := strings.Cut(value, "="); found {
+            header[k] = v
+        }
+    }
+    return header
+}
+
+func MustString(str string, err error) string {
+    return str
+}
+
+func BoolToNumber(b bool) int {
+    if b {
+        return 1
+    }
+    return 0
+}
+
+// Compute the slice (part) size.
+// The API limits the number of parts:
+// 10MiB / 20MiB slices: up to 999 parts
+// 50MiB, 60MiB, 70MiB, 80MiB, ... slices: up to 1999 parts
+func partSize(size int64) int64 {
+    const DEFAULT = 1024 * 1024 * 10 // 10MiB
+    if size > DEFAULT*2*999 {
+        return int64(math.Max(math.Ceil((float64(size)/1999) /* = size of a single slice */ /float64(DEFAULT)) /* = multiplier */, 5) * DEFAULT)
+    }
+    if size > DEFAULT*999 {
+        return DEFAULT * 2 // 20MiB
+    }
+    return DEFAULT
+}
+
+func isBool(bs ...bool) bool {
+    for _, b := range bs {
+        if b {
+            return true
+        }
+    }
+    return false
+}
+
+func IF[V any](o bool, t V, f V) V {
+    if o {
+        return t
+    }
+    return f
+}
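Because the thresholds in `partSize` are easy to misread, here is a runnable check (the function is restated verbatim so the snippet compiles on its own). A 25000 MiB file, for instance, lands on 50 MiB parts, keeping the part count under 1999:

```go
package main

import (
	"fmt"
	"math"
)

func partSize(size int64) int64 {
	const DEFAULT = 1024 * 1024 * 10 // 10MiB
	if size > DEFAULT*2*999 {
		return int64(math.Max(math.Ceil((float64(size)/1999)/float64(DEFAULT)), 5) * DEFAULT)
	}
	if size > DEFAULT*999 {
		return DEFAULT * 2 // 20MiB
	}
	return DEFAULT
}

func main() {
	const MiB = 1024 * 1024
	// 500 MiB -> 10 MiB parts; 15000 MiB -> 20 MiB parts; 25000 MiB -> 50 MiB parts
	for _, size := range []int64{500 * MiB, 15000 * MiB, 25000 * MiB} {
		fmt.Printf("%6d MiB file -> %d MiB parts\n", size/MiB, partSize(size)/MiB)
	}
}
```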
diff --git a/drivers/189pc/meta.go b/drivers/189pc/meta.go
new file mode 100644
index 0000000000000000000000000000000000000000..1891c5c0ccdf0622929a39742b373a83a2d78e15
--- /dev/null
+++ b/drivers/189pc/meta.go
@@ -0,0 +1,34 @@
+package _189pc
+
+import (
+    "github.com/alist-org/alist/v3/internal/driver"
+    "github.com/alist-org/alist/v3/internal/op"
+)
+
+type Addition struct {
+    Username string `json:"username" required:"true"`
+    Password string `json:"password" required:"true"`
+    VCode    string `json:"validate_code"`
+    driver.RootID
+    OrderBy        string `json:"order_by" type:"select" options:"filename,filesize,lastOpTime" default:"filename"`
+    OrderDirection string `json:"order_direction" type:"select" options:"asc,desc" default:"asc"`
+    Type           string `json:"type" type:"select" options:"personal,family" default:"personal"`
+    FamilyID       string `json:"family_id"`
+    UploadMethod   string `json:"upload_method" type:"select" options:"stream,rapid,old" default:"stream"`
+    UploadThread   string `json:"upload_thread" default:"3" help:"1<=thread<=32"`
+    FamilyTransfer bool   `json:"family_transfer"`
+    RapidUpload    bool   `json:"rapid_upload"`
+    NoUseOcr       bool   `json:"no_use_ocr"`
+}
+
+var config = driver.Config{
+    Name:        "189CloudPC",
+    DefaultRoot: "-11",
+    CheckStatus: true,
+}
+
+func init() {
+    op.RegisterDriver(func() driver.Driver {
+        return &Cloud189PC{}
+    })
+}
diff --git a/drivers/189pc/types.go b/drivers/189pc/types.go
new file mode 100644
index 0000000000000000000000000000000000000000..2e9ed4c203d82db2a87fe4243cfb58fc19fb369e
--- /dev/null
+++ b/drivers/189pc/types.go
@@ -0,0 +1,398 @@
+package _189pc
+
+import (
+    "encoding/xml"
+    "fmt"
+    "sort"
+    "strings"
+    "time"
+
+    "github.com/alist-org/alist/v3/pkg/utils"
+)
+
+// The API somehow answers in four different error formats
+type RespErr struct {
+    ResCode    any    `json:"res_code"` // int or string
+    ResMessage string `json:"res_message"`
+
+    Error_ string `json:"error"`
+
+    XMLName xml.Name `xml:"error"`
+    Code    string   `json:"code" xml:"code"`
+    Message string   `json:"message" xml:"message"`
+    Msg     string   `json:"msg"`
+
+    ErrorCode string `json:"errorCode"`
+    ErrorMsg  string `json:"errorMsg"`
+}
+
+func (e *RespErr) HasError() bool {
+    switch v := e.ResCode.(type) {
+    case int, int64, int32:
+        return v != 0
+    case string:
+        return e.ResCode != ""
+    }
+    return (e.Code != "" && e.Code != "SUCCESS") || e.ErrorCode != "" || e.Error_ != ""
+}
+
+func (e *RespErr) Error() string {
+    switch v := e.ResCode.(type) {
+    case int, int64, int32:
+        if v != 0 {
+            return fmt.Sprintf("res_code: %d ,res_msg: %s", v, e.ResMessage)
+        }
+    case string:
+        if e.ResCode != "" {
+            return fmt.Sprintf("res_code: %s ,res_msg: %s", e.ResCode, e.ResMessage)
+        }
+    }
+
+    if e.Code != "" && e.Code != "SUCCESS" {
+        if e.Msg != "" {
+            return fmt.Sprintf("code: %s ,msg: %s", e.Code, e.Msg)
+        }
+        if e.Message != "" {
+            return fmt.Sprintf("code: %s ,msg: %s", e.Code, e.Message)
+        }
+        return "code: " + e.Code
+    }
+
+    if e.ErrorCode != "" {
+        return fmt.Sprintf("err_code: %s ,err_msg: %s", e.ErrorCode, e.ErrorMsg)
+    }
+
+    if e.Error_ != "" {
+        return fmt.Sprintf("error: %s ,message: %s", e.Error_, e.Message)
+    }
+    return ""
+}
+
+// Parameters required for login
+type LoginParam struct {
+    // RSA-encrypted username and password
+    RsaUsername string
+    RsaPassword string
+
+    // RSA public key
+    jRsaKey string
+
+    // Request header parameters
+    Lt    string
+    ReqId string
+
+    // Form parameters
+    ParamId string
+
+    // Captcha
+    CaptchaToken string
+}
+
+// Login encryption config
+type EncryptConfResp struct {
+    Result int `json:"result"`
+    Data   struct {
+        UpSmsOn   string `json:"upSmsOn"`
+        Pre       string `json:"pre"`
+        PreDomain string `json:"preDomain"`
+        PubKey    string `json:"pubKey"`
+    } `json:"data"`
+}
+
+type LoginResp struct {
+    Msg    string `json:"msg"`
+    Result int    `json:"result"`
+    ToUrl  string `json:"toUrl"`
+}
+
+// Session refresh response
+type UserSessionResp struct {
+    ResCode    int    `json:"res_code"`
+    ResMessage string `json:"res_message"`
+
+    LoginName string `json:"loginName"`
+
+    KeepAlive       int `json:"keepAlive"`
+    GetFileDiffSpan int `json:"getFileDiffSpan"`
+    GetUserInfoSpan int `json:"getUserInfoSpan"`
+
+    // Personal cloud
+    SessionKey    string `json:"sessionKey"`
+    SessionSecret string `json:"sessionSecret"`
+    // Family cloud
+    FamilySessionKey    string `json:"familySessionKey"`
+    FamilySessionSecret string `json:"familySessionSecret"`
+}
+
+// Login response
+type AppSessionResp struct {
+    UserSessionResp
+
+    IsSaveName string `json:"isSaveName"`
+
+    // Session refresh token
+    AccessToken string `json:"accessToken"`
+    // Token refresh
+    RefreshToken string `json:"refreshToken"`
+}
+
+// Family cloud account
+type FamilyInfoListResp struct {
+    FamilyInfoResp []FamilyInfoResp `json:"familyInfoResp"`
+}
+type FamilyInfoResp struct {
+    Count      int    `json:"count"`
+    CreateTime string `json:"createTime"`
+    FamilyID   int64  `json:"familyId"`
+    RemarkName string `json:"remarkName"`
+    Type       int    `json:"type"`
+    UseFlag    int    `json:"useFlag"`
+    UserRole   int    `json:"userRole"`
+}
+
+/* Files */
+// File
+type Cloud189File struct {
+    ID   String `json:"id"`
+    Name string `json:"name"`
+    Size int64  `json:"size"`
+    Md5  string `json:"md5"`
+
+    LastOpTime Time `json:"lastOpTime"`
+    CreateDate Time `json:"createDate"`
+    Icon       struct {
+        // iconOption 5
+        SmallUrl string `json:"smallUrl"`
+        LargeUrl string `json:"largeUrl"`
+
+        // iconOption 10
+        Max600    string `json:"max600"`
+        MediumURL string `json:"mediumUrl"`
+    } `json:"icon"`
+
+    // Orientation int64  `json:"orientation"`
+    // FileCata    int64  `json:"fileCata"`
+    // MediaType   int    `json:"mediaType"`
+    // Rev         string `json:"rev"`
+    // StarLabel   int64  `json:"starLabel"`
+}
+
+func (c *Cloud189File) CreateTime() time.Time {
+    return time.Time(c.CreateDate)
+}
+
+func (c *Cloud189File) GetHash() utils.HashInfo {
+    return utils.NewHashInfo(utils.MD5, c.Md5)
+}
+
+func (c *Cloud189File) GetSize() int64     { return c.Size }
+func (c *Cloud189File) GetName() string    { return c.Name }
+func (c *Cloud189File) ModTime() time.Time { return time.Time(c.LastOpTime) }
+func (c *Cloud189File) IsDir() bool        { return false }
+func (c *Cloud189File) GetID() string      { return string(c.ID) }
+func (c *Cloud189File) GetPath() string    { return "" }
+func (c *Cloud189File) Thumb() string      { return c.Icon.SmallUrl }
+
+// Folder
+type Cloud189Folder struct {
+    ID       String `json:"id"`
+    ParentID int64  `json:"parentId"`
+    Name     string `json:"name"`
+
+    LastOpTime Time `json:"lastOpTime"`
+    CreateDate Time `json:"createDate"`
+
+    // FileListSize int64  `json:"fileListSize"`
+    // FileCount    int64  `json:"fileCount"`
+    // FileCata     int64  `json:"fileCata"`
+    // Rev          string `json:"rev"`
+    // StarLabel    int64  `json:"starLabel"`
+}
+
+func (c *Cloud189Folder) CreateTime() time.Time {
return time.Time(c.CreateDate) +} + +func (c *Cloud189Folder) GetHash() utils.HashInfo { + return utils.HashInfo{} +} + +func (c *Cloud189Folder) GetSize() int64 { return 0 } +func (c *Cloud189Folder) GetName() string { return c.Name } +func (c *Cloud189Folder) ModTime() time.Time { return time.Time(c.LastOpTime) } +func (c *Cloud189Folder) IsDir() bool { return true } +func (c *Cloud189Folder) GetID() string { return string(c.ID) } +func (c *Cloud189Folder) GetPath() string { return "" } + +type Cloud189FilesResp struct { + //ResCode int `json:"res_code"` + //ResMessage string `json:"res_message"` + FileListAO struct { + Count int `json:"count"` + FileList []Cloud189File `json:"fileList"` + FolderList []Cloud189Folder `json:"folderList"` + } `json:"fileListAO"` +} + +// TaskInfo 任务信息 +type BatchTaskInfo struct { + // FileId 文件ID + FileId string `json:"fileId"` + // FileName 文件名 + FileName string `json:"fileName"` + // IsFolder 是否是文件夹,0-否,1-是 + IsFolder int `json:"isFolder"` + // SrcParentId 文件所在父目录ID + SrcParentId string `json:"srcParentId,omitempty"` + + /* 冲突管理 */ + // 1 -> 跳过 2 -> 保留 3 -> 覆盖 + DealWay int `json:"dealWay,omitempty"` + IsConflict int `json:"isConflict,omitempty"` +} + +/* 上传部分 */ +type InitMultiUploadResp struct { + //Code string `json:"code"` + Data struct { + UploadType int `json:"uploadType"` + UploadHost string `json:"uploadHost"` + UploadFileID string `json:"uploadFileId"` + FileDataExists int `json:"fileDataExists"` + } `json:"data"` +} +type UploadUrlsResp struct { + Code string `json:"code"` + Data map[string]UploadUrlsData `json:"uploadUrls"` +} +type UploadUrlsData struct { + RequestURL string `json:"requestURL"` + RequestHeader string `json:"requestHeader"` +} + +type UploadUrlInfo struct { + PartNumber int + Headers map[string]string + UploadUrlsData +} + +type UploadProgress struct { + UploadInfo InitMultiUploadResp + UploadParts []string +} + +/* 第二种上传方式 */ +type CreateUploadFileResp struct { + // 上传文件请求ID + UploadFileId int64 `json:"uploadFileId"` + // 上传文件数据的URL路径 + FileUploadUrl string `json:"fileUploadUrl"` + // 上传文件完成后确认路径 + FileCommitUrl string `json:"fileCommitUrl"` + // 文件是否已存在云盘中,0-未存在,1-已存在 + FileDataExists int `json:"fileDataExists"` +} + +type GetUploadFileStatusResp struct { + CreateUploadFileResp + + // 已上传的大小 + DataSize int64 `json:"dataSize"` + Size int64 `json:"size"` +} + +func (r *GetUploadFileStatusResp) GetSize() int64 { + return r.DataSize + r.Size +} + +type CommitMultiUploadFileResp struct { + File struct { + UserFileID String `json:"userFileId"` + FileName string `json:"fileName"` + FileSize int64 `json:"fileSize"` + FileMd5 string `json:"fileMd5"` + CreateDate Time `json:"createDate"` + } `json:"file"` +} + +func (f *CommitMultiUploadFileResp) toFile() *Cloud189File { + return &Cloud189File{ + ID: f.File.UserFileID, + Name: f.File.FileName, + Size: f.File.FileSize, + Md5: f.File.FileMd5, + LastOpTime: f.File.CreateDate, + CreateDate: f.File.CreateDate, + } +} + +type OldCommitUploadFileResp struct { + XMLName xml.Name `xml:"file"` + ID String `xml:"id"` + Name string `xml:"name"` + Size int64 `xml:"size"` + Md5 string `xml:"md5"` + CreateDate Time `xml:"createDate"` +} + +func (f *OldCommitUploadFileResp) toFile() *Cloud189File { + return &Cloud189File{ + ID: f.ID, + Name: f.Name, + Size: f.Size, + Md5: f.Md5, + CreateDate: f.CreateDate, + LastOpTime: f.CreateDate, + } +} + +type CreateBatchTaskResp struct { + TaskID string `json:"taskId"` +} + +type BatchTaskStateResp struct { + FailedCount int `json:"failedCount"` + Process int 
`json:"process"` + SkipCount int `json:"skipCount"` + SubTaskCount int `json:"subTaskCount"` + SuccessedCount int `json:"successedCount"` + SuccessedFileIDList []int64 `json:"successedFileIdList"` + TaskID string `json:"taskId"` + TaskStatus int `json:"taskStatus"` //1 初始化 2 存在冲突 3 执行中,4 完成 +} + +type BatchTaskConflictTaskInfoResp struct { + SessionKey string `json:"sessionKey"` + TargetFolderID int `json:"targetFolderId"` + TaskID string `json:"taskId"` + TaskInfos []BatchTaskInfo + TaskType int `json:"taskType"` +} + +/* query 加密参数*/ +type Params map[string]string + +func (p Params) Set(k, v string) { + p[k] = v +} + +func (p Params) Encode() string { + if p == nil { + return "" + } + var buf strings.Builder + keys := make([]string, 0, len(p)) + for k := range p { + keys = append(keys, k) + } + sort.Strings(keys) + for i := range keys { + if buf.Len() > 0 { + buf.WriteByte('&') + } + buf.WriteString(keys[i]) + buf.WriteByte('=') + buf.WriteString(p[keys[i]]) + } + return buf.String() +} diff --git a/drivers/189pc/utils.go b/drivers/189pc/utils.go new file mode 100644 index 0000000000000000000000000000000000000000..f5a44455d2e183482aeafd6a39d5962367621f49 --- /dev/null +++ b/drivers/189pc/utils.go @@ -0,0 +1,1144 @@ +package _189pc + +import ( + "bytes" + "container/ring" + "context" + "crypto/md5" + "encoding/base64" + "encoding/hex" + "encoding/xml" + "fmt" + "io" + "math" + "net/http" + "net/http/cookiejar" + "net/url" + "regexp" + "sort" + "strconv" + "strings" + "time" + + "github.com/alist-org/alist/v3/drivers/base" + "github.com/alist-org/alist/v3/internal/conf" + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/internal/op" + "github.com/alist-org/alist/v3/internal/setting" + "github.com/alist-org/alist/v3/pkg/errgroup" + "github.com/alist-org/alist/v3/pkg/utils" + + "github.com/avast/retry-go" + "github.com/go-resty/resty/v2" + "github.com/google/uuid" + jsoniter "github.com/json-iterator/go" + "github.com/pkg/errors" +) + +const ( + ACCOUNT_TYPE = "02" + APP_ID = "8025431004" + CLIENT_TYPE = "10020" + VERSION = "6.2" + + WEB_URL = "https://cloud.189.cn" + AUTH_URL = "https://open.e.189.cn" + API_URL = "https://api.cloud.189.cn" + UPLOAD_URL = "https://upload.cloud.189.cn" + + RETURN_URL = "https://m.cloud.189.cn/zhuanti/2020/loginErrorPc/index.html" + + PC = "TELEPC" + MAC = "TELEMAC" + + CHANNEL_ID = "web_cloud.189.cn" +) + +func (y *Cloud189PC) SignatureHeader(url, method, params string, isFamily bool) map[string]string { + dateOfGmt := getHttpDateStr() + sessionKey := y.tokenInfo.SessionKey + sessionSecret := y.tokenInfo.SessionSecret + if isFamily { + sessionKey = y.tokenInfo.FamilySessionKey + sessionSecret = y.tokenInfo.FamilySessionSecret + } + + header := map[string]string{ + "Date": dateOfGmt, + "SessionKey": sessionKey, + "X-Request-ID": uuid.NewString(), + "Signature": signatureOfHmac(sessionSecret, sessionKey, method, url, dateOfGmt, params), + } + return header +} + +func (y *Cloud189PC) EncryptParams(params Params, isFamily bool) string { + sessionSecret := y.tokenInfo.SessionSecret + if isFamily { + sessionSecret = y.tokenInfo.FamilySessionSecret + } + if params != nil { + return AesECBEncrypt(params.Encode(), sessionSecret[:16]) + } + return "" +} + +func (y *Cloud189PC) request(url, method string, callback base.ReqCallback, params Params, resp interface{}, isFamily ...bool) ([]byte, error) { + req := y.client.R().SetQueryParams(clientSuffix()) + + // 设置params + paramsData := 
y.EncryptParams(params, isBool(isFamily...)) + if paramsData != "" { + req.SetQueryParam("params", paramsData) + } + + // Signature + req.SetHeaders(y.SignatureHeader(url, method, paramsData, isBool(isFamily...))) + + var erron RespErr + req.SetError(&erron) + + if callback != nil { + callback(req) + } + if resp != nil { + req.SetResult(resp) + } + res, err := req.Execute(method, url) + if err != nil { + return nil, err + } + + if strings.Contains(res.String(), "userSessionBO is null") { + if err = y.refreshSession(); err != nil { + return nil, err + } + return y.request(url, method, callback, params, resp, isFamily...) + } + + // if erron.ErrorCode == "InvalidSessionKey" || erron.Code == "InvalidSessionKey" { + if strings.Contains(res.String(), "InvalidSessionKey") { + if err = y.refreshSession(); err != nil { + return nil, err + } + return y.request(url, method, callback, params, resp, isFamily...) + } + + // 处理错误 + if erron.HasError() { + return nil, &erron + } + return res.Body(), nil +} + +func (y *Cloud189PC) get(url string, callback base.ReqCallback, resp interface{}, isFamily ...bool) ([]byte, error) { + return y.request(url, http.MethodGet, callback, nil, resp, isFamily...) +} + +func (y *Cloud189PC) post(url string, callback base.ReqCallback, resp interface{}, isFamily ...bool) ([]byte, error) { + return y.request(url, http.MethodPost, callback, nil, resp, isFamily...) +} + +func (y *Cloud189PC) put(ctx context.Context, url string, headers map[string]string, sign bool, file io.Reader, isFamily bool) ([]byte, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodPut, url, file) + if err != nil { + return nil, err + } + + query := req.URL.Query() + for key, value := range clientSuffix() { + query.Add(key, value) + } + req.URL.RawQuery = query.Encode() + + for key, value := range headers { + req.Header.Add(key, value) + } + + if sign { + for key, value := range y.SignatureHeader(url, http.MethodPut, "", isFamily) { + req.Header.Add(key, value) + } + } + + resp, err := base.HttpClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + var erron RespErr + jsoniter.Unmarshal(body, &erron) + xml.Unmarshal(body, &erron) + if erron.HasError() { + return nil, &erron + } + if resp.StatusCode != http.StatusOK { + return nil, errors.Errorf("put fail,err:%s", string(body)) + } + return body, nil +} +func (y *Cloud189PC) getFiles(ctx context.Context, fileId string, isFamily bool) ([]model.Obj, error) { + fullUrl := API_URL + if isFamily { + fullUrl += "/family/file" + } + fullUrl += "/listFiles.action" + + res := make([]model.Obj, 0, 130) + for pageNum := 1; ; pageNum++ { + var resp Cloud189FilesResp + _, err := y.get(fullUrl, func(r *resty.Request) { + r.SetContext(ctx) + r.SetQueryParams(map[string]string{ + "folderId": fileId, + "fileType": "0", + "mediaAttr": "0", + "iconOption": "5", + "pageNum": fmt.Sprint(pageNum), + "pageSize": "130", + }) + if isFamily { + r.SetQueryParams(map[string]string{ + "familyId": y.FamilyID, + "orderBy": toFamilyOrderBy(y.OrderBy), + "descending": toDesc(y.OrderDirection), + }) + } else { + r.SetQueryParams(map[string]string{ + "recursive": "0", + "orderBy": y.OrderBy, + "descending": toDesc(y.OrderDirection), + }) + } + }, &resp, isFamily) + if err != nil { + return nil, err + } + // 获取完毕跳出 + if resp.FileListAO.Count == 0 { + break + } + + for i := 0; i < len(resp.FileListAO.FolderList); i++ { + res = append(res, 
&resp.FileListAO.FolderList[i]) + } + for i := 0; i < len(resp.FileListAO.FileList); i++ { + res = append(res, &resp.FileListAO.FileList[i]) + } + } + return res, nil +} + +func (y *Cloud189PC) login() (err error) { + // 初始化登陆所需参数 + if y.loginParam == nil { + if err = y.initLoginParam(); err != nil { + // 验证码也通过错误返回 + return err + } + } + defer func() { + // 销毁验证码 + y.VCode = "" + // 销毁登陆参数 + y.loginParam = nil + // 遇到错误,重新加载登陆参数(刷新验证码) + if err != nil && y.NoUseOcr { + if err1 := y.initLoginParam(); err1 != nil { + err = fmt.Errorf("err1: %s \nerr2: %s", err, err1) + } + } + }() + + param := y.loginParam + var loginresp LoginResp + _, err = y.client.R(). + ForceContentType("application/json;charset=UTF-8").SetResult(&loginresp). + SetHeaders(map[string]string{ + "REQID": param.ReqId, + "lt": param.Lt, + }). + SetFormData(map[string]string{ + "appKey": APP_ID, + "accountType": ACCOUNT_TYPE, + "userName": param.RsaUsername, + "password": param.RsaPassword, + "validateCode": y.VCode, + "captchaToken": param.CaptchaToken, + "returnUrl": RETURN_URL, + // "mailSuffix": "@189.cn", + "dynamicCheck": "FALSE", + "clientType": CLIENT_TYPE, + "cb_SaveName": "1", + "isOauth2": "false", + "state": "", + "paramId": param.ParamId, + }). + Post(AUTH_URL + "/api/logbox/oauth2/loginSubmit.do") + if err != nil { + return err + } + if loginresp.ToUrl == "" { + return fmt.Errorf("login failed,No toUrl obtained, msg: %s", loginresp.Msg) + } + + // 获取Session + var erron RespErr + var tokenInfo AppSessionResp + _, err = y.client.R(). + SetResult(&tokenInfo).SetError(&erron). + SetQueryParams(clientSuffix()). + SetQueryParam("redirectURL", url.QueryEscape(loginresp.ToUrl)). + Post(API_URL + "/getSessionForPC.action") + if err != nil { + return + } + + if erron.HasError() { + return &erron + } + if tokenInfo.ResCode != 0 { + err = fmt.Errorf(tokenInfo.ResMessage) + return + } + y.tokenInfo = &tokenInfo + return +} + +/* 初始化登陆需要的参数 +* 如果遇到验证码返回错误 + */ +func (y *Cloud189PC) initLoginParam() error { + // 清除cookie + jar, _ := cookiejar.New(nil) + y.client.SetCookieJar(jar) + + res, err := y.client.R(). + SetQueryParams(map[string]string{ + "appId": APP_ID, + "clientType": CLIENT_TYPE, + "returnURL": RETURN_URL, + "timeStamp": fmt.Sprint(timestamp()), + }). + Get(WEB_URL + "/api/portal/unifyLoginForPC.action") + if err != nil { + return err + } + + param := LoginParam{ + CaptchaToken: regexp.MustCompile(`'captchaToken' value='(.+?)'`).FindStringSubmatch(res.String())[1], + Lt: regexp.MustCompile(`lt = "(.+?)"`).FindStringSubmatch(res.String())[1], + ParamId: regexp.MustCompile(`paramId = "(.+?)"`).FindStringSubmatch(res.String())[1], + ReqId: regexp.MustCompile(`reqId = "(.+?)"`).FindStringSubmatch(res.String())[1], + // jRsaKey: regexp.MustCompile(`"j_rsaKey" value="(.+?)"`).FindStringSubmatch(res.String())[1], + } + + // 获取rsa公钥 + var encryptConf EncryptConfResp + _, err = y.client.R(). + ForceContentType("application/json;charset=UTF-8").SetResult(&encryptConf). + SetFormData(map[string]string{"appId": APP_ID}). + Post(AUTH_URL + "/api/logbox/config/encryptConf.do") + if err != nil { + return err + } + + param.jRsaKey = fmt.Sprintf("-----BEGIN PUBLIC KEY-----\n%s\n-----END PUBLIC KEY-----", encryptConf.Data.PubKey) + param.RsaUsername = encryptConf.Data.Pre + RsaEncrypt(param.jRsaKey, y.Username) + param.RsaPassword = encryptConf.Data.Pre + RsaEncrypt(param.jRsaKey, y.Password) + y.loginParam = ¶m + + // 判断是否需要验证码 + resp, err := y.client.R(). + SetHeader("REQID", param.ReqId). 
+
+/* Initialize the parameters needed for login.
+ * Returns an error if a captcha is required.
+ */
+func (y *Cloud189PC) initLoginParam() error {
+    // Clear cookies
+    jar, _ := cookiejar.New(nil)
+    y.client.SetCookieJar(jar)
+
+    res, err := y.client.R().
+        SetQueryParams(map[string]string{
+            "appId":      APP_ID,
+            "clientType": CLIENT_TYPE,
+            "returnURL":  RETURN_URL,
+            "timeStamp":  fmt.Sprint(timestamp()),
+        }).
+        Get(WEB_URL + "/api/portal/unifyLoginForPC.action")
+    if err != nil {
+        return err
+    }
+
+    param := LoginParam{
+        CaptchaToken: regexp.MustCompile(`'captchaToken' value='(.+?)'`).FindStringSubmatch(res.String())[1],
+        Lt:           regexp.MustCompile(`lt = "(.+?)"`).FindStringSubmatch(res.String())[1],
+        ParamId:      regexp.MustCompile(`paramId = "(.+?)"`).FindStringSubmatch(res.String())[1],
+        ReqId:        regexp.MustCompile(`reqId = "(.+?)"`).FindStringSubmatch(res.String())[1],
+        // jRsaKey: regexp.MustCompile(`"j_rsaKey" value="(.+?)"`).FindStringSubmatch(res.String())[1],
+    }
+
+    // Fetch the RSA public key
+    var encryptConf EncryptConfResp
+    _, err = y.client.R().
+        ForceContentType("application/json;charset=UTF-8").SetResult(&encryptConf).
+        SetFormData(map[string]string{"appId": APP_ID}).
+        Post(AUTH_URL + "/api/logbox/config/encryptConf.do")
+    if err != nil {
+        return err
+    }
+
+    param.jRsaKey = fmt.Sprintf("-----BEGIN PUBLIC KEY-----\n%s\n-----END PUBLIC KEY-----", encryptConf.Data.PubKey)
+    param.RsaUsername = encryptConf.Data.Pre + RsaEncrypt(param.jRsaKey, y.Username)
+    param.RsaPassword = encryptConf.Data.Pre + RsaEncrypt(param.jRsaKey, y.Password)
+    y.loginParam = &param
+
+    // Check whether a captcha is required
+    resp, err := y.client.R().
+        SetHeader("REQID", param.ReqId).
+        SetFormData(map[string]string{
+            "appKey":      APP_ID,
+            "accountType": ACCOUNT_TYPE,
+            "userName":    param.RsaUsername,
+        }).Post(AUTH_URL + "/api/logbox/oauth2/needcaptcha.do")
+    if err != nil {
+        return err
+    }
+    if resp.String() == "0" {
+        return nil
+    }
+
+    // Fetch the captcha image
+    imgRes, err := y.client.R().
+        SetQueryParams(map[string]string{
+            "token": param.CaptchaToken,
+            "REQID": param.ReqId,
+            "rnd":   fmt.Sprint(timestamp()),
+        }).
+        Get(AUTH_URL + "/api/logbox/oauth2/picCaptcha.do")
+    if err != nil {
+        return fmt.Errorf("failed to obtain verification code")
+    }
+    if imgRes.Size() > 20 {
+        if setting.GetStr(conf.OcrApi) != "" && !y.NoUseOcr {
+            vRes, err := base.RestyClient.R().
+                SetMultipartField("image", "validateCode.png", "image/png", bytes.NewReader(imgRes.Body())).
+                Post(setting.GetStr(conf.OcrApi))
+            if err != nil {
+                return err
+            }
+            if jsoniter.Get(vRes.Body(), "status").ToInt() == 200 {
+                y.VCode = jsoniter.Get(vRes.Body(), "result").ToString()
+                return nil
+            }
+        }
+
+        // Return the captcha image to the frontend
+        return fmt.Errorf("need img validate code: %s", base64.StdEncoding.EncodeToString(imgRes.Body()))
+    }
+    return nil
+}
+
+// Refresh the session
+func (y *Cloud189PC) refreshSession() (err error) {
+    var erron RespErr
+    var userSessionResp UserSessionResp
+    _, err = y.client.R().
+        SetResult(&userSessionResp).SetError(&erron).
+        SetQueryParams(clientSuffix()).
+        SetQueryParams(map[string]string{
+            "appId":       APP_ID,
+            "accessToken": y.tokenInfo.AccessToken,
+        }).
+        SetHeader("X-Request-ID", uuid.NewString()).
+        Get(API_URL + "/getSessionForPC.action")
+    if err != nil {
+        return err
+    }
+
+    // An error here breaks normal access, so take this storage offline
+    defer func() {
+        if err != nil {
+            y.GetStorage().SetStatus(fmt.Sprintf("%+v", err.Error()))
+            op.MustSaveDriverStorage(y)
+        }
+    }()
+
+    if erron.HasError() {
+        if erron.ResCode == "UserInvalidOpenToken" {
+            if err = y.login(); err != nil {
+                return err
+            }
+        }
+        return &erron
+    }
+    y.tokenInfo.UserSessionResp = userSessionResp
+    return
+}
+
+// Regular upload
+// Cannot upload zero-size files
+func (y *Cloud189PC) StreamUpload(ctx context.Context, dstDir model.Obj, file model.FileStreamer, up driver.UpdateProgress, isFamily bool, overwrite bool) (model.Obj, error) {
+    var sliceSize = partSize(file.GetSize())
+    count := int(math.Ceil(float64(file.GetSize()) / float64(sliceSize)))
+    lastPartSize := file.GetSize() % sliceSize
+    if file.GetSize() > 0 && lastPartSize == 0 {
+        lastPartSize = sliceSize
+    }
+
+    params := Params{
+        "parentFolderId": dstDir.GetID(),
+        "fileName":       url.QueryEscape(file.GetName()),
+        "fileSize":       fmt.Sprint(file.GetSize()),
+        "sliceSize":      fmt.Sprint(sliceSize),
+        "lazyCheck":      "1",
+    }
+
+    fullUrl := UPLOAD_URL
+    if isFamily {
+        params.Set("familyId", y.FamilyID)
+        fullUrl += "/family"
+    } else {
+        //params.Set("extend", `{"opScene":"1","relativepath":"","rootfolderid":""}`)
+        fullUrl += "/person"
+    }
+
+    // Initialize the upload
+    var initMultiUpload InitMultiUploadResp
+    _, err := y.request(fullUrl+"/initMultiUpload", http.MethodGet, func(req *resty.Request) {
+        req.SetContext(ctx)
+    }, params, &initMultiUpload, isFamily)
+    if err != nil {
+        return nil, err
+    }
+
+    threadG, upCtx := errgroup.NewGroupWithContext(ctx, y.uploadThread,
+        retry.Attempts(3),
+        retry.Delay(time.Second),
+        retry.DelayType(retry.BackOffDelay))
+
+    fileMd5 := md5.New()
+    silceMd5 := md5.New()
+    silceMd5Hexs := make([]string, 0, count)
+
+    for i := 1; i <= count; i++ {
+        if utils.IsCanceled(upCtx) {
+            break
+        }
+
+        byteData := make([]byte, sliceSize)
+        if i == count {
+            byteData = byteData[:lastPartSize]
+        }
+
+        // Read a chunk
+        silceMd5.Reset()
+        if _, err := 
io.ReadFull(io.TeeReader(file, io.MultiWriter(fileMd5, silceMd5)), byteData); err != io.EOF && err != nil { + return nil, err + } + + // 计算块md5并进行hex和base64编码 + md5Bytes := silceMd5.Sum(nil) + silceMd5Hexs = append(silceMd5Hexs, strings.ToUpper(hex.EncodeToString(md5Bytes))) + partInfo := fmt.Sprintf("%d-%s", i, base64.StdEncoding.EncodeToString(md5Bytes)) + + threadG.Go(func(ctx context.Context) error { + uploadUrls, err := y.GetMultiUploadUrls(ctx, isFamily, initMultiUpload.Data.UploadFileID, partInfo) + if err != nil { + return err + } + + // step.4 上传切片 + uploadUrl := uploadUrls[0] + _, err = y.put(ctx, uploadUrl.RequestURL, uploadUrl.Headers, false, bytes.NewReader(byteData), isFamily) + if err != nil { + return err + } + up(float64(threadG.Success()) * 100 / float64(count)) + return nil + }) + } + if err = threadG.Wait(); err != nil { + return nil, err + } + + fileMd5Hex := strings.ToUpper(hex.EncodeToString(fileMd5.Sum(nil))) + sliceMd5Hex := fileMd5Hex + if file.GetSize() > sliceSize { + sliceMd5Hex = strings.ToUpper(utils.GetMD5EncodeStr(strings.Join(silceMd5Hexs, "\n"))) + } + + // 提交上传 + var resp CommitMultiUploadFileResp + _, err = y.request(fullUrl+"/commitMultiUploadFile", http.MethodGet, + func(req *resty.Request) { + req.SetContext(ctx) + }, Params{ + "uploadFileId": initMultiUpload.Data.UploadFileID, + "fileMd5": fileMd5Hex, + "sliceMd5": sliceMd5Hex, + "lazyCheck": "1", + "isLog": "0", + "opertype": IF(overwrite, "3", "1"), + }, &resp, isFamily) + if err != nil { + return nil, err + } + return resp.toFile(), nil +} + +func (y *Cloud189PC) RapidUpload(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, isFamily bool, overwrite bool) (model.Obj, error) { + fileMd5 := stream.GetHash().GetHash(utils.MD5) + if len(fileMd5) < utils.MD5.Width { + return nil, errors.New("invalid hash") + } + + uploadInfo, err := y.OldUploadCreate(ctx, dstDir.GetID(), fileMd5, stream.GetName(), fmt.Sprint(stream.GetSize()), isFamily) + if err != nil { + return nil, err + } + + if uploadInfo.FileDataExists != 1 { + return nil, errors.New("rapid upload fail") + } + + return y.OldUploadCommit(ctx, uploadInfo.FileCommitUrl, uploadInfo.UploadFileId, isFamily, overwrite) +} + +// 快传 +func (y *Cloud189PC) FastUpload(ctx context.Context, dstDir model.Obj, file model.FileStreamer, up driver.UpdateProgress, isFamily bool, overwrite bool) (model.Obj, error) { + tempFile, err := file.CacheFullInTempFile() + if err != nil { + return nil, err + } + + var sliceSize = partSize(file.GetSize()) + count := int(math.Ceil(float64(file.GetSize()) / float64(sliceSize))) + lastSliceSize := file.GetSize() % sliceSize + if file.GetSize() > 0 && lastSliceSize == 0 { + lastSliceSize = sliceSize + } + + //step.1 优先计算所需信息 + byteSize := sliceSize + fileMd5 := md5.New() + silceMd5 := md5.New() + silceMd5Hexs := make([]string, 0, count) + partInfos := make([]string, 0, count) + for i := 1; i <= count; i++ { + if utils.IsCanceled(ctx) { + return nil, ctx.Err() + } + + if i == count { + byteSize = lastSliceSize + } + + silceMd5.Reset() + if _, err := utils.CopyWithBufferN(io.MultiWriter(fileMd5, silceMd5), tempFile, byteSize); err != nil && err != io.EOF { + return nil, err + } + md5Byte := silceMd5.Sum(nil) + silceMd5Hexs = append(silceMd5Hexs, strings.ToUpper(hex.EncodeToString(md5Byte))) + partInfos = append(partInfos, fmt.Sprint(i, "-", base64.StdEncoding.EncodeToString(md5Byte))) + } + + fileMd5Hex := strings.ToUpper(hex.EncodeToString(fileMd5.Sum(nil))) + sliceMd5Hex := fileMd5Hex + if file.GetSize() > sliceSize 
{ + sliceMd5Hex = strings.ToUpper(utils.GetMD5EncodeStr(strings.Join(silceMd5Hexs, "\n"))) + } + + fullUrl := UPLOAD_URL + if isFamily { + fullUrl += "/family" + } else { + //params.Set("extend", `{"opScene":"1","relativepath":"","rootfolderid":""}`) + fullUrl += "/person" + } + + // 尝试恢复进度 + uploadProgress, ok := base.GetUploadProgress[*UploadProgress](y, y.tokenInfo.SessionKey, fileMd5Hex) + if !ok { + //step.2 预上传 + params := Params{ + "parentFolderId": dstDir.GetID(), + "fileName": url.QueryEscape(file.GetName()), + "fileSize": fmt.Sprint(file.GetSize()), + "fileMd5": fileMd5Hex, + "sliceSize": fmt.Sprint(sliceSize), + "sliceMd5": sliceMd5Hex, + } + if isFamily { + params.Set("familyId", y.FamilyID) + } + var uploadInfo InitMultiUploadResp + _, err = y.request(fullUrl+"/initMultiUpload", http.MethodGet, func(req *resty.Request) { + req.SetContext(ctx) + }, params, &uploadInfo, isFamily) + if err != nil { + return nil, err + } + uploadProgress = &UploadProgress{ + UploadInfo: uploadInfo, + UploadParts: partInfos, + } + } + + uploadInfo := uploadProgress.UploadInfo.Data + // 网盘中不存在该文件,开始上传 + if uploadInfo.FileDataExists != 1 { + threadG, upCtx := errgroup.NewGroupWithContext(ctx, y.uploadThread, + retry.Attempts(3), + retry.Delay(time.Second), + retry.DelayType(retry.BackOffDelay)) + for i, uploadPart := range uploadProgress.UploadParts { + if utils.IsCanceled(upCtx) { + break + } + + i, uploadPart := i, uploadPart + threadG.Go(func(ctx context.Context) error { + // step.3 获取上传链接 + uploadUrls, err := y.GetMultiUploadUrls(ctx, isFamily, uploadInfo.UploadFileID, uploadPart) + if err != nil { + return err + } + uploadUrl := uploadUrls[0] + + byteSize, offset := sliceSize, int64(uploadUrl.PartNumber-1)*sliceSize + if uploadUrl.PartNumber == count { + byteSize = lastSliceSize + } + + // step.4 上传切片 + _, err = y.put(ctx, uploadUrl.RequestURL, uploadUrl.Headers, false, io.NewSectionReader(tempFile, offset, byteSize), isFamily) + if err != nil { + return err + } + + up(float64(threadG.Success()) * 100 / float64(len(uploadUrls))) + uploadProgress.UploadParts[i] = "" + return nil + }) + } + if err = threadG.Wait(); err != nil { + if errors.Is(err, context.Canceled) { + uploadProgress.UploadParts = utils.SliceFilter(uploadProgress.UploadParts, func(s string) bool { return s != "" }) + base.SaveUploadProgress(y, uploadProgress, y.tokenInfo.SessionKey, fileMd5Hex) + } + return nil, err + } + } + + // step.5 提交 + var resp CommitMultiUploadFileResp + _, err = y.request(fullUrl+"/commitMultiUploadFile", http.MethodGet, + func(req *resty.Request) { + req.SetContext(ctx) + }, Params{ + "uploadFileId": uploadInfo.UploadFileID, + "isLog": "0", + "opertype": IF(overwrite, "3", "1"), + }, &resp, isFamily) + if err != nil { + return nil, err + } + return resp.toFile(), nil +} + +// 获取上传切片信息 +// 对http body有大小限制,分片信息太多会出错 +func (y *Cloud189PC) GetMultiUploadUrls(ctx context.Context, isFamily bool, uploadFileId string, partInfo ...string) ([]UploadUrlInfo, error) { + fullUrl := UPLOAD_URL + if isFamily { + fullUrl += "/family" + } else { + fullUrl += "/person" + } + + var uploadUrlsResp UploadUrlsResp + _, err := y.request(fullUrl+"/getMultiUploadUrls", http.MethodGet, + func(req *resty.Request) { + req.SetContext(ctx) + }, Params{ + "uploadFileId": uploadFileId, + "partInfo": strings.Join(partInfo, ","), + }, &uploadUrlsResp, isFamily) + if err != nil { + return nil, err + } + uploadUrls := uploadUrlsResp.Data + + if len(uploadUrls) != len(partInfo) { + return nil, fmt.Errorf("uploadUrls get error, due to get 
length %d, real length %d", len(partInfo), len(uploadUrls)) + } + + uploadUrlInfos := make([]UploadUrlInfo, 0, len(uploadUrls)) + for k, uploadUrl := range uploadUrls { + partNumber, err := strconv.Atoi(strings.TrimPrefix(k, "partNumber_")) + if err != nil { + return nil, err + } + uploadUrlInfos = append(uploadUrlInfos, UploadUrlInfo{ + PartNumber: partNumber, + Headers: ParseHttpHeader(uploadUrl.RequestHeader), + UploadUrlsData: uploadUrl, + }) + } + sort.Slice(uploadUrlInfos, func(i, j int) bool { + return uploadUrlInfos[i].PartNumber < uploadUrlInfos[j].PartNumber + }) + return uploadUrlInfos, nil +} + +// 旧版本上传,家庭云不支持覆盖 +func (y *Cloud189PC) OldUpload(ctx context.Context, dstDir model.Obj, file model.FileStreamer, up driver.UpdateProgress, isFamily bool, overwrite bool) (model.Obj, error) { + tempFile, err := file.CacheFullInTempFile() + if err != nil { + return nil, err + } + fileMd5, err := utils.HashFile(utils.MD5, tempFile) + if err != nil { + return nil, err + } + + // 创建上传会话 + uploadInfo, err := y.OldUploadCreate(ctx, dstDir.GetID(), fileMd5, file.GetName(), fmt.Sprint(file.GetSize()), isFamily) + if err != nil { + return nil, err + } + + // 网盘中不存在该文件,开始上传 + status := GetUploadFileStatusResp{CreateUploadFileResp: *uploadInfo} + for status.GetSize() < file.GetSize() && status.FileDataExists != 1 { + if utils.IsCanceled(ctx) { + return nil, ctx.Err() + } + + header := map[string]string{ + "ResumePolicy": "1", + "Expect": "100-continue", + } + + if isFamily { + header["FamilyId"] = fmt.Sprint(y.FamilyID) + header["UploadFileId"] = fmt.Sprint(status.UploadFileId) + } else { + header["Edrive-UploadFileId"] = fmt.Sprint(status.UploadFileId) + } + + _, err := y.put(ctx, status.FileUploadUrl, header, true, io.NopCloser(tempFile), isFamily) + if err, ok := err.(*RespErr); ok && err.Code != "InputStreamReadError" { + return nil, err + } + + // 获取断点状态 + fullUrl := API_URL + "/getUploadFileStatus.action" + if y.isFamily() { + fullUrl = API_URL + "/family/file/getFamilyFileStatus.action" + } + _, err = y.get(fullUrl, func(req *resty.Request) { + req.SetContext(ctx).SetQueryParams(map[string]string{ + "uploadFileId": fmt.Sprint(status.UploadFileId), + "resumePolicy": "1", + }) + if isFamily { + req.SetQueryParam("familyId", fmt.Sprint(y.FamilyID)) + } + }, &status, isFamily) + if err != nil { + return nil, err + } + if _, err := tempFile.Seek(status.GetSize(), io.SeekStart); err != nil { + return nil, err + } + up(float64(status.GetSize()) / float64(file.GetSize()) * 100) + } + + return y.OldUploadCommit(ctx, status.FileCommitUrl, status.UploadFileId, isFamily, overwrite) +} + +// 创建上传会话 +func (y *Cloud189PC) OldUploadCreate(ctx context.Context, parentID string, fileMd5, fileName, fileSize string, isFamily bool) (*CreateUploadFileResp, error) { + var uploadInfo CreateUploadFileResp + + fullUrl := API_URL + "/createUploadFile.action" + if isFamily { + fullUrl = API_URL + "/family/file/createFamilyFile.action" + } + _, err := y.post(fullUrl, func(req *resty.Request) { + req.SetContext(ctx) + if isFamily { + req.SetQueryParams(map[string]string{ + "familyId": y.FamilyID, + "parentId": parentID, + "fileMd5": fileMd5, + "fileName": fileName, + "fileSize": fileSize, + "resumePolicy": "1", + }) + } else { + req.SetFormData(map[string]string{ + "parentFolderId": parentID, + "fileName": fileName, + "size": fileSize, + "md5": fileMd5, + "opertype": "3", + "flag": "1", + "resumePolicy": "1", + "isLog": "0", + }) + } + }, &uploadInfo, isFamily) + + if err != nil { + return nil, err + } + return 
&uploadInfo, nil +} + +// 提交上传文件 +func (y *Cloud189PC) OldUploadCommit(ctx context.Context, fileCommitUrl string, uploadFileID int64, isFamily bool, overwrite bool) (model.Obj, error) { + var resp OldCommitUploadFileResp + _, err := y.post(fileCommitUrl, func(req *resty.Request) { + req.SetContext(ctx) + if isFamily { + req.SetHeaders(map[string]string{ + "ResumePolicy": "1", + "UploadFileId": fmt.Sprint(uploadFileID), + "FamilyId": fmt.Sprint(y.FamilyID), + }) + } else { + req.SetFormData(map[string]string{ + "opertype": IF(overwrite, "3", "1"), + "resumePolicy": "1", + "uploadFileId": fmt.Sprint(uploadFileID), + "isLog": "0", + }) + } + }, &resp, isFamily) + if err != nil { + return nil, err + } + return resp.toFile(), nil +} + +func (y *Cloud189PC) isFamily() bool { + return y.Type == "family" +} + +func (y *Cloud189PC) isLogin() bool { + if y.tokenInfo == nil { + return false + } + _, err := y.get(API_URL+"/getUserInfo.action", nil, nil) + return err == nil +} + +// 创建家庭云中转文件夹 +func (y *Cloud189PC) createFamilyTransferFolder(count int) (*ring.Ring, error) { + folders := ring.New(count) + var rootFolder Cloud189Folder + _, err := y.post(API_URL+"/family/file/createFolder.action", func(req *resty.Request) { + req.SetQueryParams(map[string]string{ + "folderName": "FamilyTransferFolder", + "familyId": y.FamilyID, + }) + }, &rootFolder, true) + if err != nil { + return nil, err + } + + folderCount := 0 + + // 获取已有目录 + files, err := y.getFiles(context.TODO(), rootFolder.GetID(), true) + if err != nil { + return nil, err + } + for _, file := range files { + if folder, ok := file.(*Cloud189Folder); ok { + folders.Value = folder + folders = folders.Next() + folderCount++ + } + } + + // 创建新的目录 + for folderCount < count { + var newFolder Cloud189Folder + _, err := y.post(API_URL+"/family/file/createFolder.action", func(req *resty.Request) { + req.SetQueryParams(map[string]string{ + "folderName": uuid.NewString(), + "familyId": y.FamilyID, + "parentId": rootFolder.GetID(), + }) + }, &newFolder, true) + if err != nil { + return nil, err + } + folders.Value = &newFolder + folders = folders.Next() + folderCount++ + } + return folders, nil +} + +// 清理中转文件夹 +func (y *Cloud189PC) cleanFamilyTransfer(ctx context.Context) error { + var tasks []BatchTaskInfo + r := y.familyTransferFolder + for p := r.Next(); p != r; p = p.Next() { + folder := p.Value.(*Cloud189Folder) + + files, err := y.getFiles(ctx, folder.GetID(), true) + if err != nil { + return err + } + for _, file := range files { + tasks = append(tasks, BatchTaskInfo{ + FileId: file.GetID(), + FileName: file.GetName(), + IsFolder: BoolToNumber(file.IsDir()), + }) + } + } + + if len(tasks) > 0 { + // 删除 + resp, err := y.CreateBatchTask("DELETE", y.FamilyID, "", nil, tasks...) + if err != nil { + return err + } + err = y.WaitBatchTask("DELETE", resp.TaskID, time.Second) + if err != nil { + return err + } + // 永久删除 + resp, err = y.CreateBatchTask("CLEAR_RECYCLE", y.FamilyID, "", nil, tasks...) 
+        if err != nil {
+            return err
+        }
+        err = y.WaitBatchTask("CLEAR_RECYCLE", resp.TaskID, time.Second)
+        return err
+    }
+    return nil
+}
+
+// Fetch info for all family-cloud accounts
+func (y *Cloud189PC) getFamilyInfoList() ([]FamilyInfoResp, error) {
+    var resp FamilyInfoListResp
+    _, err := y.get(API_URL+"/family/manage/getFamilyList.action", nil, &resp, true)
+    if err != nil {
+        return nil, err
+    }
+    return resp.FamilyInfoResp, nil
+}
+
+// Derive the family cloud ID
+func (y *Cloud189PC) getFamilyID() (string, error) {
+    infos, err := y.getFamilyInfoList()
+    if err != nil {
+        return "", err
+    }
+    if len(infos) == 0 {
+        return "", fmt.Errorf("cannot get family id automatically, please input family_id")
+    }
+    for _, info := range infos {
+        if strings.Contains(y.tokenInfo.LoginName, info.RemarkName) {
+            return fmt.Sprint(info.FamilyID), nil
+        }
+    }
+    return fmt.Sprint(infos[0].FamilyID), nil
+}
+
+// Save a file from the family cloud to the personal cloud
+func (y *Cloud189PC) SaveFamilyFileToPersonCloud(ctx context.Context, familyId string, srcObj, dstDir model.Obj, overwrite bool) error {
+    // _, err := y.post(API_URL+"/family/file/saveFileToMember.action", func(req *resty.Request) {
+    // 	req.SetQueryParams(map[string]string{
+    // 		"channelId":    "home",
+    // 		"familyId":     familyId,
+    // 		"destParentId": destParentId,
+    // 		"fileIdList":   familyFileId,
+    // 	})
+    // }, nil)
+    // return err
+
+    task := BatchTaskInfo{
+        FileId:   srcObj.GetID(),
+        FileName: srcObj.GetName(),
+        IsFolder: BoolToNumber(srcObj.IsDir()),
+    }
+    resp, err := y.CreateBatchTask("COPY", familyId, dstDir.GetID(), map[string]string{
+        "groupId":  "null",
+        "copyType": "2",
+        "shareId":  "null",
+    }, task)
+    if err != nil {
+        return err
+    }
+
+    for {
+        state, err := y.CheckBatchTask("COPY", resp.TaskID)
+        if err != nil {
+            return err
+        }
+        switch state.TaskStatus {
+        case 2:
+            task.DealWay = IF(overwrite, 3, 2)
+            // Overwrite the file on conflict
+            if err := y.ManageBatchTask("COPY", resp.TaskID, dstDir.GetID(), task); err != nil {
+                return err
+            }
+        case 4:
+            return nil
+        }
+        time.Sleep(time.Millisecond * 400)
+    }
+}
+
+func (y *Cloud189PC) CreateBatchTask(aType string, familyID string, targetFolderId string, other map[string]string, taskInfos ...BatchTaskInfo) (*CreateBatchTaskResp, error) {
+    var resp CreateBatchTaskResp
+    _, err := y.post(API_URL+"/batch/createBatchTask.action", func(req *resty.Request) {
+        req.SetFormData(map[string]string{
+            "type":      aType,
+            "taskInfos": MustString(utils.Json.MarshalToString(taskInfos)),
+        })
+        if targetFolderId != "" {
+            req.SetFormData(map[string]string{"targetFolderId": targetFolderId})
+        }
+        if familyID != "" {
+            req.SetFormData(map[string]string{"familyId": familyID})
+        }
+        req.SetFormData(other)
+    }, &resp, familyID != "")
+    if err != nil {
+        return nil, err
+    }
+    return &resp, nil
+}
+
+// Check task status
+func (y *Cloud189PC) CheckBatchTask(aType string, taskID string) (*BatchTaskStateResp, error) {
+    var resp BatchTaskStateResp
+    _, err := y.post(API_URL+"/batch/checkBatchTask.action", func(req *resty.Request) {
+        req.SetFormData(map[string]string{
+            "type":   aType,
+            "taskId": taskID,
+        })
+    }, &resp)
+    if err != nil {
+        return nil, err
+    }
+    return &resp, nil
+}
+
+// Fetch info about conflicting tasks
+func (y *Cloud189PC) GetConflictTaskInfo(aType string, taskID string) (*BatchTaskConflictTaskInfoResp, error) {
+    var resp BatchTaskConflictTaskInfoResp
+    _, err := y.post(API_URL+"/batch/getConflictTaskInfo.action", func(req *resty.Request) {
+        req.SetFormData(map[string]string{
+            "type":   aType,
+            "taskId": taskID,
+        })
+    }, &resp)
+    if err != nil {
+        return nil, err
+    }
+    return &resp, nil
+}
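Taken together, the helpers above and `ManageBatchTask`/`WaitBatchTask` below form a small lifecycle: create a task, poll it, and resolve conflicts (`TaskStatus == 2`) by resubmitting the task infos with a `DealWay` decision. A hedged sketch of that composition, assumed to live in this package; `moveWithConflictHandling` is a hypothetical helper, and `errors` refers to the github.com/pkg/errors import this file already uses:

```go
// moveWithConflictHandling is an illustrative wrapper, not part of the driver.
func moveWithConflictHandling(y *Cloud189PC, src BatchTaskInfo, targetFolderId string) error {
	resp, err := y.CreateBatchTask("MOVE", "", targetFolderId, nil, src)
	if err != nil {
		return err
	}
	if err := y.WaitBatchTask("MOVE", resp.TaskID, time.Millisecond*400); err != nil {
		if errors.Is(err, ErrIsConflict) {
			src.DealWay = 3 // 3 = overwrite (1 = skip, 2 = keep both)
			if err := y.ManageBatchTask("MOVE", resp.TaskID, targetFolderId, src); err != nil {
				return err
			}
			// Wait again for the resubmitted task to finish.
			return y.WaitBatchTask("MOVE", resp.TaskID, time.Millisecond*400)
		}
		return err
	}
	return nil
}
```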
+
+// Handle conflicts
+func (y *Cloud189PC) ManageBatchTask(aType string, taskID string, targetFolderId string, taskInfos ...BatchTaskInfo) error {
+    _, err := y.post(API_URL+"/batch/manageBatchTask.action", func(req *resty.Request) {
+        req.SetFormData(map[string]string{
+            "targetFolderId": targetFolderId,
+            "type":           aType,
+            "taskId":         taskID,
+            "taskInfos":      MustString(utils.Json.MarshalToString(taskInfos)),
+        })
+    }, nil)
+    return err
+}
+
+var ErrIsConflict = errors.New("there is a conflict with the target object")
+
+// Wait for the task to finish
+func (y *Cloud189PC) WaitBatchTask(aType string, taskID string, t time.Duration) error {
+    for {
+        state, err := y.CheckBatchTask(aType, taskID)
+        if err != nil {
+            return err
+        }
+        switch state.TaskStatus {
+        case 2:
+            return ErrIsConflict
+        case 4:
+            return nil
+        }
+        time.Sleep(t)
+    }
+}
diff --git a/drivers/alias/driver.go b/drivers/alias/driver.go
new file mode 100644
index 0000000000000000000000000000000000000000..1b439a2c9d98a86e97fe6be258d38439646f94cb
--- /dev/null
+++ b/drivers/alias/driver.go
@@ -0,0 +1,141 @@
+package alias
+
+import (
+    "context"
+    "errors"
+    "strings"
+
+    "github.com/alist-org/alist/v3/internal/driver"
+    "github.com/alist-org/alist/v3/internal/errs"
+    "github.com/alist-org/alist/v3/internal/fs"
+    "github.com/alist-org/alist/v3/internal/model"
+    "github.com/alist-org/alist/v3/pkg/utils"
+)
+
+type Alias struct {
+    model.Storage
+    Addition
+    pathMap     map[string][]string
+    autoFlatten bool
+    oneKey      string
+}
+
+func (d *Alias) Config() driver.Config {
+    return config
+}
+
+func (d *Alias) GetAddition() driver.Additional {
+    return &d.Addition
+}
+
+func (d *Alias) Init(ctx context.Context) error {
+    if d.Paths == "" {
+        return errors.New("paths is required")
+    }
+    d.pathMap = make(map[string][]string)
+    for _, path := range strings.Split(d.Paths, "\n") {
+        path = strings.TrimSpace(path)
+        if path == "" {
+            continue
+        }
+        k, v := getPair(path)
+        d.pathMap[k] = append(d.pathMap[k], v)
+    }
+    if len(d.pathMap) == 1 {
+        for k := range d.pathMap {
+            d.oneKey = k
+        }
+        d.autoFlatten = true
+    } else {
+        d.oneKey = ""
+        d.autoFlatten = false
+    }
+    return nil
+}
+
+func (d *Alias) Drop(ctx context.Context) error {
+    d.pathMap = nil
+    return nil
+}
+
+func (d *Alias) Get(ctx context.Context, path string) (model.Obj, error) {
+    if utils.PathEqual(path, "/") {
+        return &model.Object{
+            Name:     "Root",
+            IsFolder: true,
+            Path:     "/",
+        }, nil
+    }
+    root, sub := d.getRootAndPath(path)
+    dsts, ok := d.pathMap[root]
+    if !ok {
+        return nil, errs.ObjectNotFound
+    }
+    for _, dst := range dsts {
+        obj, err := d.get(ctx, path, dst, sub)
+        if err == nil {
+            return obj, nil
+        }
+    }
+    return nil, errs.ObjectNotFound
+}
+
+func (d *Alias) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) {
+    path := dir.GetPath()
+    if utils.PathEqual(path, "/") && !d.autoFlatten {
+        return d.listRoot(), nil
+    }
+    root, sub := d.getRootAndPath(path)
+    dsts, ok := d.pathMap[root]
+    if !ok {
+        return nil, errs.ObjectNotFound
+    }
+    var objs []model.Obj
+    fsArgs := &fs.ListArgs{NoLog: true, Refresh: args.Refresh}
+    for _, dst := range dsts {
+        tmp, err := d.list(ctx, dst, sub, fsArgs)
+        if err == nil {
+            objs = append(objs, tmp...)
+        }
+    }
+    return objs, nil
+}
+
+func (d *Alias) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) {
+    root, sub := d.getRootAndPath(file.GetPath())
+    dsts, ok := d.pathMap[root]
+    if !ok {
+        return nil, errs.ObjectNotFound
+    }
+    for _, dst := range dsts {
+        link, err := d.link(ctx, dst, sub, args)
+        if err == nil {
+            return link, nil
+        }
+    }
+    return nil, errs.ObjectNotFound
+}
+
+func (d *Alias) Rename(ctx context.Context, srcObj model.Obj, newName string) error {
+    reqPath, err := d.getReqPath(ctx, srcObj)
+    if err == nil {
+        return fs.Rename(ctx, *reqPath, newName)
+    }
+    if errs.IsNotImplement(err) {
+        return errors.New("same-name files cannot be renamed")
+    }
+    return err
+}
+
+func (d *Alias) Remove(ctx context.Context, obj model.Obj) error {
+    reqPath, err := d.getReqPath(ctx, obj)
+    if err == nil {
+        return fs.Remove(ctx, *reqPath)
+    }
+    if errs.IsNotImplement(err) {
+        return errors.New("same-name files cannot be deleted")
+    }
+    return err
+}
+
+var _ driver.Driver = (*Alias)(nil)
diff --git a/drivers/alias/meta.go b/drivers/alias/meta.go
new file mode 100644
index 0000000000000000000000000000000000000000..45b885753d02e7c84cedb3772e94dc6f3ec7fd48
--- /dev/null
+++ b/drivers/alias/meta.go
@@ -0,0 +1,33 @@
+package alias
+
+import (
+    "github.com/alist-org/alist/v3/internal/driver"
+    "github.com/alist-org/alist/v3/internal/op"
+)
+
+type Addition struct {
+    // Usually one of two
+    // driver.RootPath
+    // define other
+    Paths           string `json:"paths" required:"true" type:"text"`
+    ProtectSameName bool   `json:"protect_same_name" default:"true" required:"false" help:"Protects same-name files from Delete or Rename"`
+}
+
+var config = driver.Config{
+    Name:             "Alias",
+    LocalSort:        true,
+    NoCache:          true,
+    NoUpload:         true,
+    DefaultRoot:      "/",
+    ProxyRangeOption: true,
+}
+
+func init() {
+    op.RegisterDriver(func() driver.Driver {
+        return &Alias{
+            Addition: Addition{
+                ProtectSameName: true,
+            },
+        }
+    })
+}
diff --git a/drivers/alias/types.go b/drivers/alias/types.go
new file mode 100644
index 0000000000000000000000000000000000000000..e560393da669fc613b4f1e3b730f3bf4b3919766
--- /dev/null
+++ b/drivers/alias/types.go
@@ -0,0 +1 @@
+package alias
diff --git a/drivers/alias/util.go b/drivers/alias/util.go
new file mode 100644
index 0000000000000000000000000000000000000000..c0e9081b0fcebb840dab6add7bdf33b08c977e2a
--- /dev/null
+++ b/drivers/alias/util.go
@@ -0,0 +1,151 @@
+package alias
+
+import (
+    "context"
+    "fmt"
+    stdpath "path"
+    "strings"
+
+    "github.com/alist-org/alist/v3/internal/errs"
+    "github.com/alist-org/alist/v3/internal/fs"
+    "github.com/alist-org/alist/v3/internal/model"
+    "github.com/alist-org/alist/v3/internal/sign"
+    "github.com/alist-org/alist/v3/pkg/utils"
+    "github.com/alist-org/alist/v3/server/common"
+)
+
+func (d *Alias) listRoot() []model.Obj {
+    var objs []model.Obj
+    for k := range d.pathMap {
+        obj := model.Object{
+            Name:     k,
+            IsFolder: true,
+            Modified: d.Modified,
+        }
+        objs = append(objs, &obj)
+    }
+    return objs
+}
+
+// helpers that are not defined in the Driver interface
+func getPair(path string) (string, string) {
+    //path = strings.TrimSpace(path)
+    if strings.Contains(path, ":") {
+        pair := strings.SplitN(path, ":", 2)
+        if !strings.Contains(pair[0], "/") {
+            return pair[0], pair[1]
+        }
+    }
+    return stdpath.Base(path), path
+}
+
+func (d *Alias) getRootAndPath(path string) (string, string) {
+    if d.autoFlatten {
+        return d.oneKey, path
+    }
+    path = strings.TrimPrefix(path, "/")
+    parts := strings.SplitN(path, "/", 2)
+    if 
len(parts) == 1 { + return parts[0], "" + } + return parts[0], parts[1] +} + +func (d *Alias) get(ctx context.Context, path string, dst, sub string) (model.Obj, error) { + obj, err := fs.Get(ctx, stdpath.Join(dst, sub), &fs.GetArgs{NoLog: true}) + if err != nil { + return nil, err + } + return &model.Object{ + Path: path, + Name: obj.GetName(), + Size: obj.GetSize(), + Modified: obj.ModTime(), + IsFolder: obj.IsDir(), + }, nil +} + +func (d *Alias) list(ctx context.Context, dst, sub string, args *fs.ListArgs) ([]model.Obj, error) { + objs, err := fs.List(ctx, stdpath.Join(dst, sub), args) + // the obj must implement the model.SetPath interface + // return objs, err + if err != nil { + return nil, err + } + return utils.SliceConvert(objs, func(obj model.Obj) (model.Obj, error) { + thumb, ok := model.GetThumb(obj) + objRes := model.Object{ + Name: obj.GetName(), + Size: obj.GetSize(), + Modified: obj.ModTime(), + IsFolder: obj.IsDir(), + } + if !ok { + return &objRes, nil + } + return &model.ObjThumb{ + Object: objRes, + Thumbnail: model.Thumbnail{ + Thumbnail: thumb, + }, + }, nil + }) +} + +func (d *Alias) link(ctx context.Context, dst, sub string, args model.LinkArgs) (*model.Link, error) { + reqPath := stdpath.Join(dst, sub) + storage, err := fs.GetStorage(reqPath, &fs.GetStoragesArgs{}) + if err != nil { + return nil, err + } + _, err = fs.Get(ctx, reqPath, &fs.GetArgs{NoLog: true}) + if err != nil { + return nil, err + } + if common.ShouldProxy(storage, stdpath.Base(sub)) { + link := &model.Link{ + URL: fmt.Sprintf("%s/p%s?sign=%s", + common.GetApiUrl(args.HttpReq), + utils.EncodePath(reqPath, true), + sign.Sign(reqPath)), + } + if args.HttpReq != nil && d.ProxyRange { + link.RangeReadCloser = common.NoProxyRange + } + return link, nil + } + link, _, err := fs.Link(ctx, reqPath, args) + return link, err +} + +func (d *Alias) getReqPath(ctx context.Context, obj model.Obj) (*string, error) { + root, sub := d.getRootAndPath(obj.GetPath()) + if sub == "" { + return nil, errs.NotSupport + } + dsts, ok := d.pathMap[root] + if !ok { + return nil, errs.ObjectNotFound + } + var reqPath *string + for _, dst := range dsts { + path := stdpath.Join(dst, sub) + _, err := fs.Get(ctx, path, &fs.GetArgs{NoLog: true}) + if err != nil { + continue + } + if !d.ProtectSameName { + return &path, nil + } + if ok { + ok = false + } else { + return nil, errs.NotImplement + } + reqPath = &path + } + if reqPath == nil { + return nil, errs.ObjectNotFound + } + return reqPath, nil +} diff --git a/drivers/alist_v2/driver.go b/drivers/alist_v2/driver.go new file mode 100644 index 0000000000000000000000000000000000000000..a48588cd4c44e560d7e759367d6047d04d3cb937 --- /dev/null +++ b/drivers/alist_v2/driver.go @@ -0,0 +1,118 @@ +package alist_v2 + +import ( + "context" + + "github.com/alist-org/alist/v3/drivers/base" + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/errs" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/server/common" +) + +type AListV2 struct { + model.Storage + Addition +} + +func (d *AListV2) Config() driver.Config { + return config +} + +func (d *AListV2) GetAddition() driver.Additional { + return &d.Addition +} + +func (d *AListV2) Init(ctx context.Context) error { + if len(d.Addition.Address) > 0 && string(d.Addition.Address[len(d.Addition.Address)-1]) == "/" { + d.Addition.Address = d.Addition.Address[0 : len(d.Addition.Address)-1] + } + // TODO login / refresh token + //op.MustSaveDriverStorage(d) + return nil +} + +func 
(d *AListV2) Drop(ctx context.Context) error { + return nil +} + +func (d *AListV2) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) { + url := d.Address + "/api/public/path" + var resp common.Resp[PathResp] + _, err := base.RestyClient.R(). + SetResult(&resp). + SetHeader("Authorization", d.AccessToken). + SetBody(PathReq{ + PageNum: 0, + PageSize: 0, + Path: dir.GetPath(), + Password: d.Password, + }).Post(url) + if err != nil { + return nil, err + } + var files []model.Obj + for _, f := range resp.Data.Files { + file := model.ObjThumb{ + Object: model.Object{ + Name: f.Name, + Modified: *f.UpdatedAt, + Size: f.Size, + IsFolder: f.Type == 1, + }, + Thumbnail: model.Thumbnail{Thumbnail: f.Thumbnail}, + } + files = append(files, &file) + } + return files, nil +} + +func (d *AListV2) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) { + url := d.Address + "/api/public/path" + var resp common.Resp[PathResp] + _, err := base.RestyClient.R(). + SetResult(&resp). + SetHeader("Authorization", d.AccessToken). + SetBody(PathReq{ + PageNum: 0, + PageSize: 0, + Path: file.GetPath(), + Password: d.Password, + }).Post(url) + if err != nil { + return nil, err + } + return &model.Link{ + URL: resp.Data.Files[0].Url, + }, nil +} + +func (d *AListV2) MakeDir(ctx context.Context, parentDir model.Obj, dirName string) error { + return errs.NotImplement +} + +func (d *AListV2) Move(ctx context.Context, srcObj, dstDir model.Obj) error { + return errs.NotImplement +} + +func (d *AListV2) Rename(ctx context.Context, srcObj model.Obj, newName string) error { + return errs.NotImplement +} + +func (d *AListV2) Copy(ctx context.Context, srcObj, dstDir model.Obj) error { + return errs.NotImplement +} + +func (d *AListV2) Remove(ctx context.Context, obj model.Obj) error { + return errs.NotImplement +} + +func (d *AListV2) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error { + return errs.NotImplement +} + +//func (d *AList) Other(ctx context.Context, args model.OtherArgs) (interface{}, error) { +// return nil, errs.NotSupport +//} + +var _ driver.Driver = (*AListV2)(nil) diff --git a/drivers/alist_v2/meta.go b/drivers/alist_v2/meta.go new file mode 100644 index 0000000000000000000000000000000000000000..30dd8de237b018cfbda0623e385744e45a04553d --- /dev/null +++ b/drivers/alist_v2/meta.go @@ -0,0 +1,26 @@ +package alist_v2 + +import ( + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/op" +) + +type Addition struct { + driver.RootPath + Address string `json:"url" required:"true"` + Password string `json:"password"` + AccessToken string `json:"access_token"` +} + +var config = driver.Config{ + Name: "AList V2", + LocalSort: true, + NoUpload: true, + DefaultRoot: "/", +} + +func init() { + op.RegisterDriver(func() driver.Driver { + return &AListV2{} + }) +} diff --git a/drivers/alist_v2/types.go b/drivers/alist_v2/types.go new file mode 100644 index 0000000000000000000000000000000000000000..b2317fbf0d94fdafff724ebe8cde11b1a8d57d25 --- /dev/null +++ b/drivers/alist_v2/types.go @@ -0,0 +1,31 @@ +package alist_v2 + +import ( + "time" +) + +type File struct { + Id string `json:"-"` + Name string `json:"name"` + Size int64 `json:"size"` + Type int `json:"type"` + Driver string `json:"driver"` + UpdatedAt *time.Time `json:"updated_at"` + Thumbnail string `json:"thumbnail"` + Url string `json:"url"` + SizeStr string `json:"size_str"` + TimeStr string `json:"time_str"` +} 
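+ +// Illustrative only (an assumption, not part of the original patch): with the json tags +// above and below, a /api/public/path response decodes into common.Resp[PathResp] roughly as +// {"code":200,"message":"success","data":{"type":"folder","files":[{"name":"a.txt","size":10,"type":0,"updated_at":"2022-01-02T03:04:05Z","url":"https://example.com/a.txt"}]}}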
+ +type PathResp struct { + Type string `json:"type"` + //Meta Meta `json:"meta"` + Files []File `json:"files"` +} + +type PathReq struct { + PageNum int `json:"page_num"` + PageSize int `json:"page_size"` + Password string `json:"password"` + Path string `json:"path"` +} diff --git a/drivers/alist_v2/util.go b/drivers/alist_v2/util.go new file mode 100644 index 0000000000000000000000000000000000000000..e6682193be31ffbf9c9e86128f6b9a3305404209 --- /dev/null +++ b/drivers/alist_v2/util.go @@ -0,0 +1 @@ +package alist_v2 diff --git a/drivers/alist_v3/driver.go b/drivers/alist_v3/driver.go new file mode 100644 index 0000000000000000000000000000000000000000..d078c5fb421568b3383276d856b101ce097ac0a0 --- /dev/null +++ b/drivers/alist_v3/driver.go @@ -0,0 +1,226 @@ +package alist_v3 + +import ( + "context" + "fmt" + "io" + "net/http" + "path" + "strings" + + "github.com/alist-org/alist/v3/drivers/base" + "github.com/alist-org/alist/v3/internal/conf" + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/alist-org/alist/v3/server/common" + "github.com/go-resty/resty/v2" + log "github.com/sirupsen/logrus" +) + +type AListV3 struct { + model.Storage + Addition +} + +func (d *AListV3) Config() driver.Config { + return config +} + +func (d *AListV3) GetAddition() driver.Additional { + return &d.Addition +} + +func (d *AListV3) Init(ctx context.Context) error { + d.Addition.Address = strings.TrimSuffix(d.Addition.Address, "/") + var resp common.Resp[MeResp] + _, err := d.request("/me", http.MethodGet, func(req *resty.Request) { + req.SetResult(&resp) + }) + if err != nil { + return err + } + // if the configured username differs from the user the current token belongs to, log in again + if d.Username != resp.Data.Username { + err = d.login() + if err != nil { + return err + } + } + // re-get the user info + _, err = d.request("/me", http.MethodGet, func(req *resty.Request) { + req.SetResult(&resp) + }) + if err != nil { + return err + } + if resp.Data.Role == model.GUEST { + url := d.Address + "/api/public/settings" + res, err := base.RestyClient.R().Get(url) + if err != nil { + return err + } + allowMounted := utils.Json.Get(res.Body(), "data", conf.AllowMounted).ToString() == "true" + if !allowMounted { + return fmt.Errorf("the site does not allow mounting") + } + } + return err +} + +func (d *AListV3) Drop(ctx context.Context) error { + return nil +} + +func (d *AListV3) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) { + var resp common.Resp[FsListResp] + _, err := d.request("/fs/list", http.MethodPost, func(req *resty.Request) { + req.SetResult(&resp).SetBody(ListReq{ + PageReq: model.PageReq{ + Page: 1, + PerPage: 0, + }, + Path: dir.GetPath(), + Password: d.MetaPassword, + Refresh: false, + }) + }) + if err != nil { + return nil, err + } + var files []model.Obj + for _, f := range resp.Data.Content { + file := model.ObjThumb{ + Object: model.Object{ + Name: f.Name, + Modified: f.Modified, + Ctime: f.Created, + Size: f.Size, + IsFolder: f.IsDir, + HashInfo: utils.FromString(f.HashInfo), + }, + Thumbnail: model.Thumbnail{Thumbnail: f.Thumb}, + } + files = append(files, &file) + } + return files, nil +} + +func (d *AListV3) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) { + var resp common.Resp[FsGetResp] + // if PassUAToUpsteam is true, then pass the user-agent to the upstream + userAgent := base.UserAgent + if 
d.PassUAToUpsteam { + userAgent = args.Header.Get("user-agent") + if userAgent == "" { + userAgent = base.UserAgent + } + } + _, err := d.request("/fs/get", http.MethodPost, func(req *resty.Request) { + req.SetResult(&resp).SetBody(FsGetReq{ + Path: file.GetPath(), + Password: d.MetaPassword, + }).SetHeader("user-agent", userAgent) + }) + if err != nil { + return nil, err + } + return &model.Link{ + URL: resp.Data.RawURL, + }, nil +} + +func (d *AListV3) MakeDir(ctx context.Context, parentDir model.Obj, dirName string) error { + _, err := d.request("/fs/mkdir", http.MethodPost, func(req *resty.Request) { + req.SetBody(MkdirOrLinkReq{ + Path: path.Join(parentDir.GetPath(), dirName), + }) + }) + return err +} + +func (d *AListV3) Move(ctx context.Context, srcObj, dstDir model.Obj) error { + _, err := d.request("/fs/move", http.MethodPost, func(req *resty.Request) { + req.SetBody(MoveCopyReq{ + SrcDir: path.Dir(srcObj.GetPath()), + DstDir: dstDir.GetPath(), + Names: []string{srcObj.GetName()}, + }) + }) + return err +} + +func (d *AListV3) Rename(ctx context.Context, srcObj model.Obj, newName string) error { + _, err := d.request("/fs/rename", http.MethodPost, func(req *resty.Request) { + req.SetBody(RenameReq{ + Path: srcObj.GetPath(), + Name: newName, + }) + }) + return err +} + +func (d *AListV3) Copy(ctx context.Context, srcObj, dstDir model.Obj) error { + _, err := d.request("/fs/copy", http.MethodPost, func(req *resty.Request) { + req.SetBody(MoveCopyReq{ + SrcDir: path.Dir(srcObj.GetPath()), + DstDir: dstDir.GetPath(), + Names: []string{srcObj.GetName()}, + }) + }) + return err +} + +func (d *AListV3) Remove(ctx context.Context, obj model.Obj) error { + _, err := d.request("/fs/remove", http.MethodPost, func(req *resty.Request) { + req.SetBody(RemoveReq{ + Dir: path.Dir(obj.GetPath()), + Names: []string{obj.GetName()}, + }) + }) + return err +} + +func (d *AListV3) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error { + req, err := http.NewRequestWithContext(ctx, http.MethodPut, d.Address+"/api/fs/put", stream) + if err != nil { + return err + } + req.Header.Set("Authorization", d.Token) + req.Header.Set("File-Path", path.Join(dstDir.GetPath(), stream.GetName())) + req.Header.Set("Password", d.MetaPassword) + + req.ContentLength = stream.GetSize() + // client := base.NewHttpClient() + // client.Timeout = time.Hour * 6 + res, err := base.HttpClient.Do(req) + if err != nil { + return err + } + + bytes, err := io.ReadAll(res.Body) + if err != nil { + return err + } + log.Debugf("[alist_v3] response body: %s", string(bytes)) + if res.StatusCode >= 400 { + return fmt.Errorf("request failed, status: %s", res.Status) + } + code := utils.Json.Get(bytes, "code").ToInt() + if code != 200 { + if code == 401 || code == 403 { + err = d.login() + if err != nil { + return err + } + } + return fmt.Errorf("request failed, code: %d, message: %s", code, utils.Json.Get(bytes, "message").ToString()) + } + return nil +} + +//func (d *AList) Other(ctx context.Context, args model.OtherArgs) (interface{}, error) { +// return nil, errs.NotSupport +//} + +var _ driver.Driver = (*AListV3)(nil) diff --git a/drivers/alist_v3/meta.go b/drivers/alist_v3/meta.go new file mode 100644 index 0000000000000000000000000000000000000000..cc5f21893955d34ab18463bfed930618ab4baea6 --- /dev/null +++ b/drivers/alist_v3/meta.go @@ -0,0 +1,30 @@ +package alist_v3 + +import ( + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/op" +) + +type 
Addition struct { + driver.RootPath + Address string `json:"url" required:"true"` + MetaPassword string `json:"meta_password"` + Username string `json:"username"` + Password string `json:"password"` + Token string `json:"token"` + PassUAToUpsteam bool `json:"pass_ua_to_upsteam" default:"true"` +} + +var config = driver.Config{ + Name: "AList V3", + LocalSort: true, + DefaultRoot: "/", + CheckStatus: true, + ProxyRangeOption: true, +} + +func init() { + op.RegisterDriver(func() driver.Driver { + return &AListV3{} + }) +} diff --git a/drivers/alist_v3/types.go b/drivers/alist_v3/types.go new file mode 100644 index 0000000000000000000000000000000000000000..e517307f3ef888b165fd2e3f25b7116f61778d4c --- /dev/null +++ b/drivers/alist_v3/types.go @@ -0,0 +1,83 @@ +package alist_v3 + +import ( + "time" + + "github.com/alist-org/alist/v3/internal/model" +) + +type ListReq struct { + model.PageReq + Path string `json:"path" form:"path"` + Password string `json:"password" form:"password"` + Refresh bool `json:"refresh"` +} + +type ObjResp struct { + Name string `json:"name"` + Size int64 `json:"size"` + IsDir bool `json:"is_dir"` + Modified time.Time `json:"modified"` + Created time.Time `json:"created"` + Sign string `json:"sign"` + Thumb string `json:"thumb"` + Type int `json:"type"` + HashInfo string `json:"hashinfo"` +} + +type FsListResp struct { + Content []ObjResp `json:"content"` + Total int64 `json:"total"` + Readme string `json:"readme"` + Write bool `json:"write"` + Provider string `json:"provider"` +} + +type FsGetReq struct { + Path string `json:"path" form:"path"` + Password string `json:"password" form:"password"` +} + +type FsGetResp struct { + ObjResp + RawURL string `json:"raw_url"` + Readme string `json:"readme"` + Provider string `json:"provider"` + Related []ObjResp `json:"related"` +} + +type MkdirOrLinkReq struct { + Path string `json:"path" form:"path"` +} + +type MoveCopyReq struct { + SrcDir string `json:"src_dir"` + DstDir string `json:"dst_dir"` + Names []string `json:"names"` +} + +type RenameReq struct { + Path string `json:"path"` + Name string `json:"name"` +} + +type RemoveReq struct { + Dir string `json:"dir"` + Names []string `json:"names"` +} + +type LoginResp struct { + Token string `json:"token"` +} + +type MeResp struct { + Id int `json:"id"` + Username string `json:"username"` + Password string `json:"password"` + BasePath string `json:"base_path"` + Role int `json:"role"` + Disabled bool `json:"disabled"` + Permission int `json:"permission"` + SsoId string `json:"sso_id"` + Otp bool `json:"otp"` +} diff --git a/drivers/alist_v3/util.go b/drivers/alist_v3/util.go new file mode 100644 index 0000000000000000000000000000000000000000..5ede285af5b9d5e51a86a87c6a3b19c3a8a73165 --- /dev/null +++ b/drivers/alist_v3/util.go @@ -0,0 +1,61 @@ +package alist_v3 + +import ( + "fmt" + "net/http" + + "github.com/alist-org/alist/v3/drivers/base" + "github.com/alist-org/alist/v3/internal/op" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/alist-org/alist/v3/server/common" + "github.com/go-resty/resty/v2" + log "github.com/sirupsen/logrus" +) + +func (d *AListV3) login() error { + if d.Username == "" { + return nil + } + var resp common.Resp[LoginResp] + _, err := d.request("/auth/login", http.MethodPost, func(req *resty.Request) { + req.SetResult(&resp).SetBody(base.Json{ + "username": d.Username, + "password": d.Password, + }) + }) + if err != nil { + return err + } + d.Token = resp.Data.Token + op.MustSaveDriverStorage(d) + return nil +} + +func (d *AListV3) 
request(api, method string, callback base.ReqCallback, retry ...bool) ([]byte, error) { + url := d.Address + "/api" + api + req := base.RestyClient.R() + req.SetHeader("Authorization", d.Token) + if callback != nil { + callback(req) + } + res, err := req.Execute(method, url) + if err != nil { + return nil, err + } + log.Debugf("[alist_v3] response body: %s", res.String()) + if res.StatusCode() >= 400 { + return nil, fmt.Errorf("request failed, status: %s", res.Status()) + } + code := utils.Json.Get(res.Body(), "code").ToInt() + if code != 200 { + if (code == 401 || code == 403) && !utils.IsBool(retry...) { + err = d.login() + if err != nil { + return nil, err + } + return d.request(api, method, callback, true) + } + return nil, fmt.Errorf("request failed, code: %d, message: %s", code, utils.Json.Get(res.Body(), "message").ToString()) + } + return res.Body(), nil +} diff --git a/drivers/aliyundrive/driver.go b/drivers/aliyundrive/driver.go new file mode 100644 index 0000000000000000000000000000000000000000..2a977aa35e50479b18864ee5707044eaa5517be0 --- /dev/null +++ b/drivers/aliyundrive/driver.go @@ -0,0 +1,355 @@ +package aliyundrive + +import ( + "bytes" + "context" + "crypto/sha1" + "encoding/base64" + "encoding/hex" + "fmt" + "io" + "math" + "math/big" + "net/http" + "os" + "time" + + "github.com/alist-org/alist/v3/internal/stream" + + "github.com/alist-org/alist/v3/drivers/base" + "github.com/alist-org/alist/v3/internal/conf" + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/errs" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/pkg/cron" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/go-resty/resty/v2" + log "github.com/sirupsen/logrus" +) + +type AliDrive struct { + model.Storage + Addition + AccessToken string + cron *cron.Cron + DriveId string + UserID string +} + +func (d *AliDrive) Config() driver.Config { + return config +} + +func (d *AliDrive) GetAddition() driver.Additional { + return &d.Addition +} + +func (d *AliDrive) Init(ctx context.Context) error { + // TODO login / refresh token + //op.MustSaveDriverStorage(d) + err := d.refreshToken() + if err != nil { + return err + } + // get drive id and user id + res, err, _ := d.request("https://api.alipan.com/v2/user/get", http.MethodPost, nil, nil) + if err != nil { + return err + } + d.DriveId = utils.Json.Get(res, "default_drive_id").ToString() + d.UserID = utils.Json.Get(res, "user_id").ToString() + d.cron = cron.NewCron(time.Hour * 2) + d.cron.Do(func() { + err := d.refreshToken() + if err != nil { + log.Errorf("%+v", err) + } + }) + if global.Has(d.UserID) { + return nil + } + // init deviceID + deviceID := utils.HashData(utils.SHA256, []byte(d.UserID)) + // init privateKey + privateKey, _ := NewPrivateKeyFromHex(deviceID) + state := State{ + privateKey: privateKey, + deviceID: deviceID, + } + // store state + global.Store(d.UserID, &state) + // init signature + d.sign() + return nil +} + +func (d *AliDrive) Drop(ctx context.Context) error { + if d.cron != nil { + d.cron.Stop() + } + return nil +} + +func (d *AliDrive) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) { + files, err := d.getFiles(dir.GetID()) + if err != nil { + return nil, err + } + return utils.SliceConvert(files, func(src File) (model.Obj, error) { + return fileToObj(src), nil + }) +} + +func (d *AliDrive) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) { + data := base.Json{ + "drive_id": d.DriveId, + "file_id": 
file.GetID(), + "expire_sec": 14400, + } + res, err, _ := d.request("https://api.alipan.com/v2/file/get_download_url", http.MethodPost, func(req *resty.Request) { + req.SetBody(data) + }, nil) + if err != nil { + return nil, err + } + return &model.Link{ + Header: http.Header{ + "Referer": []string{"https://www.alipan.com/"}, + }, + URL: utils.Json.Get(res, "url").ToString(), + }, nil +} + +func (d *AliDrive) MakeDir(ctx context.Context, parentDir model.Obj, dirName string) error { + _, err, _ := d.request("https://api.alipan.com/adrive/v2/file/createWithFolders", http.MethodPost, func(req *resty.Request) { + req.SetBody(base.Json{ + "check_name_mode": "refuse", + "drive_id": d.DriveId, + "name": dirName, + "parent_file_id": parentDir.GetID(), + "type": "folder", + }) + }, nil) + return err +} + +func (d *AliDrive) Move(ctx context.Context, srcObj, dstDir model.Obj) error { + err := d.batch(srcObj.GetID(), dstDir.GetID(), "/file/move") + return err +} + +func (d *AliDrive) Rename(ctx context.Context, srcObj model.Obj, newName string) error { + _, err, _ := d.request("https://api.alipan.com/v3/file/update", http.MethodPost, func(req *resty.Request) { + req.SetBody(base.Json{ + "check_name_mode": "refuse", + "drive_id": d.DriveId, + "file_id": srcObj.GetID(), + "name": newName, + }) + }, nil) + return err +} + +func (d *AliDrive) Copy(ctx context.Context, srcObj, dstDir model.Obj) error { + err := d.batch(srcObj.GetID(), dstDir.GetID(), "/file/copy") + return err +} + +func (d *AliDrive) Remove(ctx context.Context, obj model.Obj) error { + _, err, _ := d.request("https://api.alipan.com/v2/recyclebin/trash", http.MethodPost, func(req *resty.Request) { + req.SetBody(base.Json{ + "drive_id": d.DriveId, + "file_id": obj.GetID(), + }) + }, nil) + return err +} + +func (d *AliDrive) Put(ctx context.Context, dstDir model.Obj, streamer model.FileStreamer, up driver.UpdateProgress) error { + file := stream.FileStream{ + Obj: streamer, + Reader: streamer, + Mimetype: streamer.GetMimetype(), + } + const DEFAULT int64 = 10485760 // 10MB part size + var count = int(math.Ceil(float64(streamer.GetSize()) / float64(DEFAULT))) + + partInfoList := make([]base.Json, 0, count) + for i := 1; i <= count; i++ { + partInfoList = append(partInfoList, base.Json{"part_number": i}) + } + reqBody := base.Json{ + "check_name_mode": "overwrite", + "drive_id": d.DriveId, + "name": file.GetName(), + "parent_file_id": dstDir.GetID(), + "part_info_list": partInfoList, + "size": file.GetSize(), + "type": "file", + } + + var localFile *os.File + if fileStream, ok := file.Reader.(*stream.FileStream); ok { + localFile, _ = fileStream.Reader.(*os.File) + } + if d.RapidUpload { + buf := bytes.NewBuffer(make([]byte, 0, 1024)) + utils.CopyWithBufferN(buf, file, 1024) + reqBody["pre_hash"] = utils.HashData(utils.SHA1, buf.Bytes()) + if localFile != nil { + if _, err := localFile.Seek(0, io.SeekStart); err != nil { + return err + } + } else { + // stitch the consumed head back onto the front of the reader + file.Reader = struct { + io.Reader + io.Closer + }{ + Reader: io.MultiReader(buf, file), + Closer: &file, + } + } + } else { + reqBody["content_hash_name"] = "none" + reqBody["proof_version"] = "v1" + } + + var resp UploadResp + _, err, e := d.request("https://api.alipan.com/adrive/v2/file/createWithFolders", http.MethodPost, func(req *resty.Request) { + req.SetBody(reqBody) + }, &resp) + + if err != nil && e.Code != "PreHashMatched" { + return err + } + + if d.RapidUpload && e.Code == "PreHashMatched" { + delete(reqBody, "pre_hash") + h := sha1.New() + if localFile != nil { + if err = 
utils.CopyWithCtx(ctx, h, localFile, 0, nil); err != nil { + return err + } + if _, err = localFile.Seek(0, io.SeekStart); err != nil { + return err + } + } else { + tempFile, err := os.CreateTemp(conf.Conf.TempDir, "file-*") + if err != nil { + return err + } + defer func() { + _ = tempFile.Close() + _ = os.Remove(tempFile.Name()) + }() + if err = utils.CopyWithCtx(ctx, io.MultiWriter(tempFile, h), file, 0, nil); err != nil { + return err + } + localFile = tempFile + } + reqBody["content_hash"] = hex.EncodeToString(h.Sum(nil)) + reqBody["content_hash_name"] = "sha1" + reqBody["proof_version"] = "v1" + + /* + JS's implicit conversions are treacherous; not sure whether this port has bugs. Original JS: + var n = e.access_token, + r = new BigNumber('0x'.concat(md5(n).slice(0, 16))), + i = new BigNumber(t.file.size), + o = i ? r.mod(i) : new gt.BigNumber(0); + (t.file.slice(o.toNumber(), Math.min(o.plus(8).toNumber(), t.file.size))) + */ + buf := make([]byte, 8) + r, _ := new(big.Int).SetString(utils.GetMD5EncodeStr(d.AccessToken)[:16], 16) + i := new(big.Int).SetInt64(file.GetSize()) + o := new(big.Int).SetInt64(0) + if file.GetSize() > 0 { + o = r.Mod(r, i) + } + n, _ := io.NewSectionReader(localFile, o.Int64(), 8).Read(buf[:8]) + reqBody["proof_code"] = base64.StdEncoding.EncodeToString(buf[:n]) + + _, err, e := d.request("https://api.alipan.com/adrive/v2/file/createWithFolders", http.MethodPost, func(req *resty.Request) { + req.SetBody(reqBody) + }, &resp) + if err != nil && e.Code != "PreHashMatched" { + return err + } + if resp.RapidUpload { + return nil + } + // rapid upload failed, fall back to a normal upload + if _, err = localFile.Seek(0, io.SeekStart); err != nil { + return err + } + file.Reader = localFile + } + + for i, partInfo := range resp.PartInfoList { + if utils.IsCanceled(ctx) { + return ctx.Err() + } + url := partInfo.UploadUrl + if d.InternalUpload { + url = partInfo.InternalUploadUrl + } + req, err := http.NewRequest("PUT", url, io.LimitReader(file, DEFAULT)) + if err != nil { + return err + } + req = req.WithContext(ctx) + res, err := base.HttpClient.Do(req) + if err != nil { + return err + } + res.Body.Close() + if count > 0 { + up(float64(i) * 100 / float64(count)) + } + } + var resp2 base.Json + _, err, e = d.request("https://api.alipan.com/v2/file/complete", http.MethodPost, func(req *resty.Request) { + req.SetBody(base.Json{ + "drive_id": d.DriveId, + "file_id": resp.FileId, + "upload_id": resp.UploadId, + }) + }, &resp2) + if err != nil && e.Code != "PreHashMatched" { + return err + } + if resp2["file_id"] == resp.FileId { + return nil + } + return fmt.Errorf("%+v", resp2) +} + +func (d *AliDrive) Other(ctx context.Context, args model.OtherArgs) (interface{}, error) { + var resp base.Json + var url string + data := base.Json{ + "drive_id": d.DriveId, + "file_id": args.Obj.GetID(), + } + switch args.Method { + case "doc_preview": + url = "https://api.alipan.com/v2/file/get_office_preview_url" + data["access_token"] = d.AccessToken + case "video_preview": + url = "https://api.alipan.com/v2/file/get_video_preview_play_info" + data["category"] = "live_transcoding" + data["url_expire_sec"] = 14400 + default: + return nil, errs.NotSupport + } + _, err, _ := d.request(url, http.MethodPost, func(req *resty.Request) { + req.SetBody(data) + }, &resp) + if err != nil { + return nil, err + } + return resp, nil +} + +var _ driver.Driver = (*AliDrive)(nil) diff --git a/drivers/aliyundrive/global.go b/drivers/aliyundrive/global.go new file mode 100644 index 0000000000000000000000000000000000000000..a0929780656dea4a4c917c91cfeb3bf96064d4f8 --- /dev/null +++ b/drivers/aliyundrive/global.go @@ 
-0,0 +1,16 @@ +package aliyundrive + +import ( + "crypto/ecdsa" + + "github.com/alist-org/alist/v3/pkg/generic_sync" +) + +type State struct { + deviceID string + signature string + retry int + privateKey *ecdsa.PrivateKey +} + +var global = generic_sync.MapOf[string, *State]{} diff --git a/drivers/aliyundrive/help.go b/drivers/aliyundrive/help.go new file mode 100644 index 0000000000000000000000000000000000000000..2037f545530618e1264e4e0e59b263f3a9806e0e --- /dev/null +++ b/drivers/aliyundrive/help.go @@ -0,0 +1,66 @@ +package aliyundrive + +import ( + "crypto/ecdsa" + "crypto/rand" + "encoding/hex" + "math/big" + + "github.com/dustinxie/ecc" +) + +func NewPrivateKey() (*ecdsa.PrivateKey, error) { + p256k1 := ecc.P256k1() + return ecdsa.GenerateKey(p256k1, rand.Reader) +} + +func NewPrivateKeyFromHex(hex_ string) (*ecdsa.PrivateKey, error) { + data, err := hex.DecodeString(hex_) + if err != nil { + return nil, err + } + return NewPrivateKeyFromBytes(data), nil + +} + +func NewPrivateKeyFromBytes(priv []byte) *ecdsa.PrivateKey { + p256k1 := ecc.P256k1() + x, y := p256k1.ScalarBaseMult(priv) + return &ecdsa.PrivateKey{ + PublicKey: ecdsa.PublicKey{ + Curve: p256k1, + X: x, + Y: y, + }, + D: new(big.Int).SetBytes(priv), + } +} + +func PrivateKeyToHex(private *ecdsa.PrivateKey) string { + return hex.EncodeToString(PrivateKeyToBytes(private)) +} + +func PrivateKeyToBytes(private *ecdsa.PrivateKey) []byte { + return private.D.Bytes() +} + +func PublicKeyToHex(public *ecdsa.PublicKey) string { + return hex.EncodeToString(PublicKeyToBytes(public)) +} + +func PublicKeyToBytes(public *ecdsa.PublicKey) []byte { + x := public.X.Bytes() + if len(x) < 32 { + for i := 0; i < 32-len(x); i++ { + x = append([]byte{0}, x...) + } + } + + y := public.Y.Bytes() + if len(y) < 32 { + for i := 0; i < 32-len(y); i++ { + y = append([]byte{0}, y...) + } + } + return append(x, y...) +} diff --git a/drivers/aliyundrive/meta.go b/drivers/aliyundrive/meta.go new file mode 100644 index 0000000000000000000000000000000000000000..9aee856908d238e8105fc13554cf350fd9c9f759 --- /dev/null +++ b/drivers/aliyundrive/meta.go @@ -0,0 +1,30 @@ +package aliyundrive + +import ( + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/op" +) + +type Addition struct { + driver.RootID + RefreshToken string `json:"refresh_token" required:"true"` + //DeviceID string `json:"device_id" required:"true"` + OrderBy string `json:"order_by" type:"select" options:"name,size,updated_at,created_at"` + OrderDirection string `json:"order_direction" type:"select" options:"ASC,DESC"` + RapidUpload bool `json:"rapid_upload"` + InternalUpload bool `json:"internal_upload"` +} + +var config = driver.Config{ + Name: "Aliyundrive", + DefaultRoot: "root", + Alert: `warning|There may be an infinite loop bug in this driver. +Deprecated, no longer maintained and will be removed in a future version. 
+We recommend using the official driver AliyundriveOpen.`, +} + +func init() { + op.RegisterDriver(func() driver.Driver { + return &AliDrive{} + }) +} diff --git a/drivers/aliyundrive/types.go b/drivers/aliyundrive/types.go new file mode 100644 index 0000000000000000000000000000000000000000..e74d5f58c6c134ea907bcc6894ba537db1c8182c --- /dev/null +++ b/drivers/aliyundrive/types.go @@ -0,0 +1,56 @@ +package aliyundrive + +import ( + "time" + + "github.com/alist-org/alist/v3/internal/model" +) + +type RespErr struct { + Code string `json:"code"` + Message string `json:"message"` +} + +type Files struct { + Items []File `json:"items"` + NextMarker string `json:"next_marker"` +} + +type File struct { + DriveId string `json:"drive_id"` + CreatedAt *time.Time `json:"created_at"` + FileExtension string `json:"file_extension"` + FileId string `json:"file_id"` + Type string `json:"type"` + Name string `json:"name"` + Category string `json:"category"` + ParentFileId string `json:"parent_file_id"` + UpdatedAt time.Time `json:"updated_at"` + Size int64 `json:"size"` + Thumbnail string `json:"thumbnail"` + Url string `json:"url"` +} + +func fileToObj(f File) *model.ObjThumb { + return &model.ObjThumb{ + Object: model.Object{ + ID: f.FileId, + Name: f.Name, + Size: f.Size, + Modified: f.UpdatedAt, + IsFolder: f.Type == "folder", + }, + Thumbnail: model.Thumbnail{Thumbnail: f.Thumbnail}, + } +} + +type UploadResp struct { + FileId string `json:"file_id"` + UploadId string `json:"upload_id"` + PartInfoList []struct { + UploadUrl string `json:"upload_url"` + InternalUploadUrl string `json:"internal_upload_url"` + } `json:"part_info_list"` + + RapidUpload bool `json:"rapid_upload"` +} diff --git a/drivers/aliyundrive/util.go b/drivers/aliyundrive/util.go new file mode 100644 index 0000000000000000000000000000000000000000..0e81b082bb911dbb6c5f9e6285008e6a80102b41 --- /dev/null +++ b/drivers/aliyundrive/util.go @@ -0,0 +1,204 @@ +package aliyundrive + +import ( + "crypto/sha256" + "encoding/hex" + "errors" + "fmt" + "net/http" + + "github.com/alist-org/alist/v3/drivers/base" + "github.com/alist-org/alist/v3/internal/op" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/dustinxie/ecc" + "github.com/go-resty/resty/v2" + "github.com/google/uuid" +) + +func (d *AliDrive) createSession() error { + state, ok := global.Load(d.UserID) + if !ok { + return fmt.Errorf("can't load user state, user_id: %s", d.UserID) + } + d.sign() + state.retry++ + if state.retry > 3 { + state.retry = 0 + return fmt.Errorf("createSession failed after three retries") + } + _, err, _ := d.request("https://api.alipan.com/users/v1/users/device/create_session", http.MethodPost, func(req *resty.Request) { + req.SetBody(base.Json{ + "deviceName": "samsung", + "modelName": "SM-G9810", + "nonce": 0, + "pubKey": PublicKeyToHex(&state.privateKey.PublicKey), + "refreshToken": d.RefreshToken, + }) + }, nil) + if err == nil{ + state.retry = 0 + } + return err +} + +// func (d *AliDrive) renewSession() error { +// _, err, _ := d.request("https://api.alipan.com/users/v1/users/device/renew_session", http.MethodPost, nil, nil) +// return err +// } + +func (d *AliDrive) sign() { + state, _ := global.Load(d.UserID) + secpAppID := "5dde4e1bdf9e4966b387ba58f4b3fdc3" + singdata := fmt.Sprintf("%s:%s:%s:%d", secpAppID, state.deviceID, d.UserID, 0) + hash := sha256.Sum256([]byte(singdata)) + data, _ := ecc.SignBytes(state.privateKey, hash[:], ecc.RecID|ecc.LowerS) + state.signature = hex.EncodeToString(data) //strconv.Itoa(state.nonce) +} + +// do others 
that not defined in Driver interface + +func (d *AliDrive) refreshToken() error { + url := "https://auth.alipan.com/v2/account/token" + var resp base.TokenResp + var e RespErr + _, err := base.RestyClient.R(). + //ForceContentType("application/json"). + SetBody(base.Json{"refresh_token": d.RefreshToken, "grant_type": "refresh_token"}). + SetResult(&resp). + SetError(&e). + Post(url) + if err != nil { + return err + } + if e.Code != "" { + return fmt.Errorf("failed to refresh token: %s", e.Message) + } + if resp.RefreshToken == "" { + return errors.New("failed to refresh token: refresh token is empty") + } + d.RefreshToken, d.AccessToken = resp.RefreshToken, resp.AccessToken + op.MustSaveDriverStorage(d) + return nil +} + +func (d *AliDrive) request(url, method string, callback base.ReqCallback, resp interface{}) ([]byte, error, RespErr) { + req := base.RestyClient.R() + state, ok := global.Load(d.UserID) + if !ok { + if url == "https://api.alipan.com/v2/user/get" { + state = &State{} + } else { + return nil, fmt.Errorf("can't load user state, user_id: %s", d.UserID), RespErr{} + } + } + req.SetHeaders(map[string]string{ + "Authorization": "Bearer\t" + d.AccessToken, + "content-type": "application/json", + "origin": "https://www.alipan.com", + "Referer": "https://alipan.com/", + "X-Signature": state.signature, + "x-request-id": uuid.NewString(), + "X-Canary": "client=Android,app=adrive,version=v4.1.0", + "X-Device-Id": state.deviceID, + }) + if callback != nil { + callback(req) + } else { + req.SetBody("{}") + } + if resp != nil { + req.SetResult(resp) + } + var e RespErr + req.SetError(&e) + res, err := req.Execute(method, url) + if err != nil { + return nil, err, e + } + if e.Code != "" { + switch e.Code { + case "AccessTokenInvalid": + err = d.refreshToken() + if err != nil { + return nil, err, e + } + case "DeviceSessionSignatureInvalid": + err = d.createSession() + if err != nil { + return nil, err, e + } + default: + return nil, errors.New(e.Message), e + } + return d.request(url, method, callback, resp) + } else if res.IsError() { + return nil, errors.New("bad status code " + res.Status()), e + } + return res.Body(), nil, e +} + +func (d *AliDrive) getFiles(fileId string) ([]File, error) { + marker := "first" + res := make([]File, 0) + for marker != "" { + if marker == "first" { + marker = "" + } + var resp Files + data := base.Json{ + "drive_id": d.DriveId, + "fields": "*", + "image_thumbnail_process": "image/resize,w_400/format,jpeg", + "image_url_process": "image/resize,w_1920/format,jpeg", + "limit": 200, + "marker": marker, + "order_by": d.OrderBy, + "order_direction": d.OrderDirection, + "parent_file_id": fileId, + "video_thumbnail_process": "video/snapshot,t_0,f_jpg,ar_auto,w_300", + "url_expire_sec": 14400, + } + _, err, _ := d.request("https://api.alipan.com/v2/file/list", http.MethodPost, func(req *resty.Request) { + req.SetBody(data) + }, &resp) + + if err != nil { + return nil, err + } + marker = resp.NextMarker + res = append(res, resp.Items...) 
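+ // "first" is only a sentinel so the loop body runs once with an empty marker; paging stops when the server returns an empty next_marker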
+ } + return res, nil +} + +func (d *AliDrive) batch(srcId, dstId string, url string) error { + res, err, _ := d.request("https://api.alipan.com/v3/batch", http.MethodPost, func(req *resty.Request) { + req.SetBody(base.Json{ + "requests": []base.Json{ + { + "headers": base.Json{ + "Content-Type": "application/json", + }, + "method": "POST", + "id": srcId, + "body": base.Json{ + "drive_id": d.DriveId, + "file_id": srcId, + "to_drive_id": d.DriveId, + "to_parent_file_id": dstId, + }, + "url": url, + }, + }, + "resource": "file", + }) + }, nil) + if err != nil { + return err + } + status := utils.Json.Get(res, "responses", 0, "status").ToInt() + if status < 400 && status >= 100 { + return nil + } + return errors.New(string(res)) +} diff --git a/drivers/aliyundrive_open/driver.go b/drivers/aliyundrive_open/driver.go new file mode 100644 index 0000000000000000000000000000000000000000..a65ba05c5af9505665817ca4899fc230031fc2aa --- /dev/null +++ b/drivers/aliyundrive_open/driver.go @@ -0,0 +1,237 @@ +package aliyundrive_open + +import ( + "context" + "errors" + "fmt" + "net/http" + "time" + + "github.com/Xhofe/rateg" + "github.com/alist-org/alist/v3/drivers/base" + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/errs" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/go-resty/resty/v2" +) + +type AliyundriveOpen struct { + model.Storage + Addition + + DriveId string + + limitList func(ctx context.Context, data base.Json) (*Files, error) + limitLink func(ctx context.Context, file model.Obj) (*model.Link, error) + ref *AliyundriveOpen +} + +func (d *AliyundriveOpen) Config() driver.Config { + return config +} + +func (d *AliyundriveOpen) GetAddition() driver.Additional { + return &d.Addition +} + +func (d *AliyundriveOpen) Init(ctx context.Context) error { + if d.LIVPDownloadFormat == "" { + d.LIVPDownloadFormat = "jpeg" + } + if d.DriveType == "" { + d.DriveType = "default" + } + res, err := d.request("/adrive/v1.0/user/getDriveInfo", http.MethodPost, nil) + if err != nil { + return err + } + d.DriveId = utils.Json.Get(res, d.DriveType+"_drive_id").ToString() + d.limitList = rateg.LimitFnCtx(d.list, rateg.LimitFnOption{ + Limit: 4, + Bucket: 1, + }) + d.limitLink = rateg.LimitFnCtx(d.link, rateg.LimitFnOption{ + Limit: 1, + Bucket: 1, + }) + return nil +} + +func (d *AliyundriveOpen) InitReference(storage driver.Driver) error { + refStorage, ok := storage.(*AliyundriveOpen) + if ok { + d.ref = refStorage + return nil + } + return errs.NotSupport +} + +func (d *AliyundriveOpen) Drop(ctx context.Context) error { + d.ref = nil + return nil +} + +func (d *AliyundriveOpen) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) { + if d.limitList == nil { + return nil, fmt.Errorf("driver not init") + } + files, err := d.getFiles(ctx, dir.GetID()) + if err != nil { + return nil, err + } + return utils.SliceConvert(files, func(src File) (model.Obj, error) { + return fileToObj(src), nil + }) +} + +func (d *AliyundriveOpen) link(ctx context.Context, file model.Obj) (*model.Link, error) { + res, err := d.request("/adrive/v1.0/openFile/getDownloadUrl", http.MethodPost, func(req *resty.Request) { + req.SetBody(base.Json{ + "drive_id": d.DriveId, + "file_id": file.GetID(), + "expire_sec": 14400, + }) + }) + if err != nil { + return nil, err + } + url := utils.Json.Get(res, "url").ToString() + if url == "" { + if utils.Ext(file.GetName()) != "livp" { + return nil, errors.New("get download 
url failed: " + string(res)) + } + url = utils.Json.Get(res, "streamsUrl", d.LIVPDownloadFormat).ToString() + } + exp := time.Minute + return &model.Link{ + URL: url, + Expiration: &exp, + }, nil +} + +func (d *AliyundriveOpen) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) { + if d.limitLink == nil { + return nil, fmt.Errorf("driver not init") + } + return d.limitLink(ctx, file) +} + +func (d *AliyundriveOpen) MakeDir(ctx context.Context, parentDir model.Obj, dirName string) (model.Obj, error) { + nowTime, _ := getNowTime() + newDir := File{CreatedAt: nowTime, UpdatedAt: nowTime} + _, err := d.request("/adrive/v1.0/openFile/create", http.MethodPost, func(req *resty.Request) { + req.SetBody(base.Json{ + "drive_id": d.DriveId, + "parent_file_id": parentDir.GetID(), + "name": dirName, + "type": "folder", + "check_name_mode": "refuse", + }).SetResult(&newDir) + }) + if err != nil { + return nil, err + } + return fileToObj(newDir), nil +} + +func (d *AliyundriveOpen) Move(ctx context.Context, srcObj, dstDir model.Obj) (model.Obj, error) { + var resp MoveOrCopyResp + _, err := d.request("/adrive/v1.0/openFile/move", http.MethodPost, func(req *resty.Request) { + req.SetBody(base.Json{ + "drive_id": d.DriveId, + "file_id": srcObj.GetID(), + "to_parent_file_id": dstDir.GetID(), + "check_name_mode": "refuse", // optional:ignore,auto_rename,refuse + //"new_name": "newName", // The new name to use when a file of the same name exists + }).SetResult(&resp) + }) + if err != nil { + return nil, err + } + if resp.Exist { + return nil, errors.New("existence of files with the same name") + } + + if srcObj, ok := srcObj.(*model.ObjThumb); ok { + srcObj.ID = resp.FileID + srcObj.Modified = time.Now() + return srcObj, nil + } + return nil, nil +} + +func (d *AliyundriveOpen) Rename(ctx context.Context, srcObj model.Obj, newName string) (model.Obj, error) { + var newFile File + _, err := d.request("/adrive/v1.0/openFile/update", http.MethodPost, func(req *resty.Request) { + req.SetBody(base.Json{ + "drive_id": d.DriveId, + "file_id": srcObj.GetID(), + "name": newName, + }).SetResult(&newFile) + }) + if err != nil { + return nil, err + } + return fileToObj(newFile), nil +} + +func (d *AliyundriveOpen) Copy(ctx context.Context, srcObj, dstDir model.Obj) error { + _, err := d.request("/adrive/v1.0/openFile/copy", http.MethodPost, func(req *resty.Request) { + req.SetBody(base.Json{ + "drive_id": d.DriveId, + "file_id": srcObj.GetID(), + "to_parent_file_id": dstDir.GetID(), + "auto_rename": true, + }) + }) + return err +} + +func (d *AliyundriveOpen) Remove(ctx context.Context, obj model.Obj) error { + uri := "/adrive/v1.0/openFile/recyclebin/trash" + if d.RemoveWay == "delete" { + uri = "/adrive/v1.0/openFile/delete" + } + _, err := d.request(uri, http.MethodPost, func(req *resty.Request) { + req.SetBody(base.Json{ + "drive_id": d.DriveId, + "file_id": obj.GetID(), + }) + }) + return err +} + +func (d *AliyundriveOpen) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) (model.Obj, error) { + return d.upload(ctx, dstDir, stream, up) +} + +func (d *AliyundriveOpen) Other(ctx context.Context, args model.OtherArgs) (interface{}, error) { + var resp base.Json + var uri string + data := base.Json{ + "drive_id": d.DriveId, + "file_id": args.Obj.GetID(), + } + switch args.Method { + case "video_preview": + uri = "/adrive/v1.0/openFile/getVideoPreviewPlayInfo" + data["category"] = "live_transcoding" + data["url_expire_sec"] = 14400 + default: 
+ return nil, errs.NotSupport + } + _, err := d.request(uri, http.MethodPost, func(req *resty.Request) { + req.SetBody(data).SetResult(&resp) + }) + if err != nil { + return nil, err + } + return resp, nil +} + +var _ driver.Driver = (*AliyundriveOpen)(nil) +var _ driver.MkdirResult = (*AliyundriveOpen)(nil) +var _ driver.MoveResult = (*AliyundriveOpen)(nil) +var _ driver.RenameResult = (*AliyundriveOpen)(nil) +var _ driver.PutResult = (*AliyundriveOpen)(nil) diff --git a/drivers/aliyundrive_open/meta.go b/drivers/aliyundrive_open/meta.go new file mode 100644 index 0000000000000000000000000000000000000000..03f97f8b795fd34bf804089154e64f933f8a767b --- /dev/null +++ b/drivers/aliyundrive_open/meta.go @@ -0,0 +1,41 @@ +package aliyundrive_open + +import ( + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/op" +) + +type Addition struct { + DriveType string `json:"drive_type" type:"select" options:"default,resource,backup" default:"resource"` + driver.RootID + RefreshToken string `json:"refresh_token" required:"true"` + OrderBy string `json:"order_by" type:"select" options:"name,size,updated_at,created_at"` + OrderDirection string `json:"order_direction" type:"select" options:"ASC,DESC"` + OauthTokenURL string `json:"oauth_token_url" default:"https://api.nn.ci/alist/ali_open/token"` + ClientID string `json:"client_id" required:"false" help:"Keep it empty if you don't have one"` + ClientSecret string `json:"client_secret" required:"false" help:"Keep it empty if you don't have one"` + RemoveWay string `json:"remove_way" required:"true" type:"select" options:"trash,delete"` + RapidUpload bool `json:"rapid_upload" help:"If you enable this option, the file will be uploaded to the server first, so the progress will be incorrect"` + InternalUpload bool `json:"internal_upload" help:"If you are using Aliyun ECS is located in Beijing, you can turn it on to boost the upload speed"` + LIVPDownloadFormat string `json:"livp_download_format" type:"select" options:"jpeg,mov" default:"jpeg"` + AccessToken string +} + +var config = driver.Config{ + Name: "AliyundriveOpen", + LocalSort: false, + OnlyLocal: false, + OnlyProxy: false, + NoCache: false, + NoUpload: false, + NeedMs: false, + DefaultRoot: "root", + NoOverwriteUpload: true, +} +var API_URL = "https://openapi.alipan.com" + +func init() { + op.RegisterDriver(func() driver.Driver { + return &AliyundriveOpen{} + }) +} diff --git a/drivers/aliyundrive_open/types.go b/drivers/aliyundrive_open/types.go new file mode 100644 index 0000000000000000000000000000000000000000..46830a5133645076c6221b60dd909defabc30e15 --- /dev/null +++ b/drivers/aliyundrive_open/types.go @@ -0,0 +1,84 @@ +package aliyundrive_open + +import ( + "github.com/alist-org/alist/v3/pkg/utils" + "time" + + "github.com/alist-org/alist/v3/internal/model" +) + +type ErrResp struct { + Code string `json:"code"` + Message string `json:"message"` +} + +type Files struct { + Items []File `json:"items"` + NextMarker string `json:"next_marker"` +} + +type File struct { + DriveId string `json:"drive_id"` + FileId string `json:"file_id"` + ParentFileId string `json:"parent_file_id"` + Name string `json:"name"` + Size int64 `json:"size"` + FileExtension string `json:"file_extension"` + ContentHash string `json:"content_hash"` + Category string `json:"category"` + Type string `json:"type"` + Thumbnail string `json:"thumbnail"` + Url string `json:"url"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` + + // create only + FileName 
string `json:"file_name"` +} + +func fileToObj(f File) *model.ObjThumb { + if f.Name == "" { + f.Name = f.FileName + } + return &model.ObjThumb{ + Object: model.Object{ + ID: f.FileId, + Name: f.Name, + Size: f.Size, + Modified: f.UpdatedAt, + IsFolder: f.Type == "folder", + Ctime: f.CreatedAt, + HashInfo: utils.NewHashInfo(utils.SHA1, f.ContentHash), + }, + Thumbnail: model.Thumbnail{Thumbnail: f.Thumbnail}, + } +} + +type PartInfo struct { + Etag interface{} `json:"etag"` + PartNumber int `json:"part_number"` + PartSize interface{} `json:"part_size"` + UploadUrl string `json:"upload_url"` + ContentType string `json:"content_type"` +} + +type CreateResp struct { + //Type string `json:"type"` + //ParentFileId string `json:"parent_file_id"` + //DriveId string `json:"drive_id"` + FileId string `json:"file_id"` + //RevisionId string `json:"revision_id"` + //EncryptMode string `json:"encrypt_mode"` + //DomainId string `json:"domain_id"` + //FileName string `json:"file_name"` + UploadId string `json:"upload_id"` + //Location string `json:"location"` + RapidUpload bool `json:"rapid_upload"` + PartInfoList []PartInfo `json:"part_info_list"` +} + +type MoveOrCopyResp struct { + Exist bool `json:"exist"` + DriveID string `json:"drive_id"` + FileID string `json:"file_id"` +} diff --git a/drivers/aliyundrive_open/upload.go b/drivers/aliyundrive_open/upload.go new file mode 100644 index 0000000000000000000000000000000000000000..653a24423465a4f2f2d3f768383f0b9fa9a986be --- /dev/null +++ b/drivers/aliyundrive_open/upload.go @@ -0,0 +1,273 @@ +package aliyundrive_open + +import ( + "bytes" + "context" + "encoding/base64" + "fmt" + "io" + "math" + "net/http" + "strconv" + "strings" + "time" + + "github.com/alist-org/alist/v3/drivers/base" + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/pkg/http_range" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/avast/retry-go" + "github.com/go-resty/resty/v2" + log "github.com/sirupsen/logrus" +) + +func makePartInfos(size int) []base.Json { + partInfoList := make([]base.Json, size) + for i := 0; i < size; i++ { + partInfoList[i] = base.Json{"part_number": 1 + i} + } + return partInfoList +} + +func calPartSize(fileSize int64) int64 { + var partSize int64 = 20 * utils.MB + if fileSize > partSize { + if fileSize > 1*utils.TB { // file Size over 1TB + partSize = 5 * utils.GB // file part size 5GB + } else if fileSize > 768*utils.GB { // over 768GB + partSize = 109951163 // ≈ 104.8576MB, split 1TB into 10,000 part + } else if fileSize > 512*utils.GB { // over 512GB + partSize = 82463373 // ≈ 78.6432MB + } else if fileSize > 384*utils.GB { // over 384GB + partSize = 54975582 // ≈ 52.4288MB + } else if fileSize > 256*utils.GB { // over 256GB + partSize = 41231687 // ≈ 39.3216MB + } else if fileSize > 128*utils.GB { // over 128GB + partSize = 27487791 // ≈ 26.2144MB + } + } + return partSize +} + +func (d *AliyundriveOpen) getUploadUrl(count int, fileId, uploadId string) ([]PartInfo, error) { + partInfoList := makePartInfos(count) + var resp CreateResp + _, err := d.request("/adrive/v1.0/openFile/getUploadUrl", http.MethodPost, func(req *resty.Request) { + req.SetBody(base.Json{ + "drive_id": d.DriveId, + "file_id": fileId, + "part_info_list": partInfoList, + "upload_id": uploadId, + }).SetResult(&resp) + }) + return resp.PartInfoList, err +} + +func (d *AliyundriveOpen) uploadPart(ctx context.Context, r io.Reader, partInfo PartInfo) error { + uploadUrl := partInfo.UploadUrl + if 
d.InternalUpload { + uploadUrl = strings.ReplaceAll(uploadUrl, "https://cn-beijing-data.aliyundrive.net/", "http://ccp-bj29-bj-1592982087.oss-cn-beijing-internal.aliyuncs.com/") + } + req, err := http.NewRequestWithContext(ctx, "PUT", uploadUrl, r) + if err != nil { + return err + } + res, err := base.HttpClient.Do(req) + if err != nil { + return err + } + res.Body.Close() + if res.StatusCode != http.StatusOK && res.StatusCode != http.StatusConflict { + return fmt.Errorf("upload status: %d", res.StatusCode) + } + return nil +} + +func (d *AliyundriveOpen) completeUpload(fileId, uploadId string) (model.Obj, error) { + // 3. complete + var newFile File + _, err := d.request("/adrive/v1.0/openFile/complete", http.MethodPost, func(req *resty.Request) { + req.SetBody(base.Json{ + "drive_id": d.DriveId, + "file_id": fileId, + "upload_id": uploadId, + }).SetResult(&newFile) + }) + if err != nil { + return nil, err + } + return fileToObj(newFile), nil +} + +type ProofRange struct { + Start int64 + End int64 +} + +func getProofRange(input string, size int64) (*ProofRange, error) { + if size == 0 { + return &ProofRange{}, nil + } + tmpStr := utils.GetMD5EncodeStr(input)[0:16] + tmpInt, err := strconv.ParseUint(tmpStr, 16, 64) + if err != nil { + return nil, err + } + index := tmpInt % uint64(size) + pr := &ProofRange{ + Start: int64(index), + End: int64(index) + 8, + } + if pr.End >= size { + pr.End = size + } + return pr, nil +} + +func (d *AliyundriveOpen) calProofCode(stream model.FileStreamer) (string, error) { + proofRange, err := getProofRange(d.getAccessToken(), stream.GetSize()) + if err != nil { + return "", err + } + length := proofRange.End - proofRange.Start + buf := bytes.NewBuffer(make([]byte, 0, length)) + reader, err := stream.RangeRead(http_range.Range{Start: proofRange.Start, Length: length}) + if err != nil { + return "", err + } + _, err = utils.CopyWithBufferN(buf, reader, length) + if err != nil { + return "", err + } + return base64.StdEncoding.EncodeToString(buf.Bytes()), nil +} + +func (d *AliyundriveOpen) upload(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) (model.Obj, error) { + // 1. 
create + // Part Size Unit: Bytes, Default: 20MB, + // Maximum number of slices 10,000, ≈195.3125GB + var partSize = calPartSize(stream.GetSize()) + const dateFormat = "2006-01-02T15:04:05.000Z" + mtimeStr := stream.ModTime().UTC().Format(dateFormat) + ctimeStr := stream.CreateTime().UTC().Format(dateFormat) + + createData := base.Json{ + "drive_id": d.DriveId, + "parent_file_id": dstDir.GetID(), + "name": stream.GetName(), + "type": "file", + "check_name_mode": "ignore", + "local_modified_at": mtimeStr, + "local_created_at": ctimeStr, + } + count := int(math.Ceil(float64(stream.GetSize()) / float64(partSize))) + createData["part_info_list"] = makePartInfos(count) + // rapid upload + rapidUpload := !stream.IsForceStreamUpload() && stream.GetSize() > 100*utils.KB && d.RapidUpload + if rapidUpload { + log.Debugf("[aliyundrive_open] start cal pre_hash") + // read 1024 bytes to calculate pre hash + reader, err := stream.RangeRead(http_range.Range{Start: 0, Length: 1024}) + if err != nil { + return nil, err + } + hash, err := utils.HashReader(utils.SHA1, reader) + if err != nil { + return nil, err + } + createData["size"] = stream.GetSize() + createData["pre_hash"] = hash + } + var createResp CreateResp + _, err, e := d.requestReturnErrResp("/adrive/v1.0/openFile/create", http.MethodPost, func(req *resty.Request) { + req.SetBody(createData).SetResult(&createResp) + }) + var tmpF model.File + if err != nil { + if e.Code != "PreHashMatched" || !rapidUpload { + return nil, err + } + log.Debugf("[aliyundrive_open] pre_hash matched, start rapid upload") + + hi := stream.GetHash() + hash := hi.GetHash(utils.SHA1) + if len(hash) <= 0 { + tmpF, err = stream.CacheFullInTempFile() + if err != nil { + return nil, err + } + hash, err = utils.HashFile(utils.SHA1, tmpF) + if err != nil { + return nil, err + } + + } + + delete(createData, "pre_hash") + createData["proof_version"] = "v1" + createData["content_hash_name"] = "sha1" + createData["content_hash"] = hash + createData["proof_code"], err = d.calProofCode(stream) + if err != nil { + return nil, fmt.Errorf("cal proof code error: %s", err.Error()) + } + _, err = d.request("/adrive/v1.0/openFile/create", http.MethodPost, func(req *resty.Request) { + req.SetBody(createData).SetResult(&createResp) + }) + if err != nil { + return nil, err + } + } + + if !createResp.RapidUpload { + // 2. 
normal upload + log.Debugf("[aliyundive_open] normal upload") + + preTime := time.Now() + var offset, length int64 = 0, partSize + //var length + for i := 0; i < len(createResp.PartInfoList); i++ { + if utils.IsCanceled(ctx) { + return nil, ctx.Err() + } + // refresh upload url if 50 minutes passed + if time.Since(preTime) > 50*time.Minute { + createResp.PartInfoList, err = d.getUploadUrl(count, createResp.FileId, createResp.UploadId) + if err != nil { + return nil, err + } + preTime = time.Now() + } + if remain := stream.GetSize() - offset; length > remain { + length = remain + } + rd := utils.NewMultiReadable(io.LimitReader(stream, partSize)) + if rapidUpload { + srd, err := stream.RangeRead(http_range.Range{Start: offset, Length: length}) + if err != nil { + return nil, err + } + rd = utils.NewMultiReadable(srd) + } + err = retry.Do(func() error { + rd.Reset() + return d.uploadPart(ctx, rd, createResp.PartInfoList[i]) + }, + retry.Attempts(3), + retry.DelayType(retry.BackOffDelay), + retry.Delay(time.Second)) + if err != nil { + return nil, err + } + offset += partSize + up(float64(i*100) / float64(count)) + } + } else { + log.Debugf("[aliyundrive_open] rapid upload success, file id: %s", createResp.FileId) + } + + log.Debugf("[aliyundrive_open] create file success, resp: %+v", createResp) + // 3. complete + return d.completeUpload(createResp.FileId, createResp.UploadId) +} diff --git a/drivers/aliyundrive_open/util.go b/drivers/aliyundrive_open/util.go new file mode 100644 index 0000000000000000000000000000000000000000..659d7da72572122690bd1adde585fab58dbd7177 --- /dev/null +++ b/drivers/aliyundrive_open/util.go @@ -0,0 +1,188 @@ +package aliyundrive_open + +import ( + "context" + "encoding/base64" + "errors" + "fmt" + "net/http" + "strings" + "time" + + "github.com/alist-org/alist/v3/drivers/base" + "github.com/alist-org/alist/v3/internal/op" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/go-resty/resty/v2" + log "github.com/sirupsen/logrus" +) + +// do others that not defined in Driver interface + +func (d *AliyundriveOpen) _refreshToken() (string, string, error) { + url := API_URL + "/oauth/access_token" + if d.OauthTokenURL != "" && d.ClientID == "" { + url = d.OauthTokenURL + } + //var resp base.TokenResp + var e ErrResp + res, err := base.RestyClient.R(). + //ForceContentType("application/json"). + SetBody(base.Json{ + "client_id": d.ClientID, + "client_secret": d.ClientSecret, + "grant_type": "refresh_token", + "refresh_token": d.RefreshToken, + }). + //SetResult(&resp). + SetError(&e). 
+ Post(url) + if err != nil { + return "", "", err + } + log.Debugf("[ali_open] refresh token response: %s", res.String()) + if e.Code != "" { + return "", "", fmt.Errorf("failed to refresh token: %s", e.Message) + } + refresh, access := utils.Json.Get(res.Body(), "refresh_token").ToString(), utils.Json.Get(res.Body(), "access_token").ToString() + if refresh == "" { + return "", "", fmt.Errorf("failed to refresh token: refresh token is empty, resp: %s", res.String()) + } + curSub, err := getSub(d.RefreshToken) + if err != nil { + return "", "", err + } + newSub, err := getSub(refresh) + if err != nil { + return "", "", err + } + if curSub != newSub { + return "", "", errors.New("failed to refresh token: sub not match") + } + return refresh, access, nil +} + +func getSub(token string) (string, error) { + segments := strings.Split(token, ".") + if len(segments) != 3 { + return "", errors.New("not a jwt token because of invalid segments") + } + bs, err := base64.RawStdEncoding.DecodeString(segments[1]) + if err != nil { + return "", errors.New("failed to decode jwt token") + } + return utils.Json.Get(bs, "sub").ToString(), nil +} + +func (d *AliyundriveOpen) refreshToken() error { + if d.ref != nil { + return d.ref.refreshToken() + } + refresh, access, err := d._refreshToken() + for i := 0; i < 3; i++ { + if err == nil { + break + } else { + log.Errorf("[ali_open] failed to refresh token: %s", err) + } + refresh, access, err = d._refreshToken() + } + if err != nil { + return err + } + log.Infof("[ali_open] token exchange: %s -> %s", d.RefreshToken, refresh) + d.RefreshToken, d.AccessToken = refresh, access + op.MustSaveDriverStorage(d) + return nil +} + +func (d *AliyundriveOpen) request(uri, method string, callback base.ReqCallback, retry ...bool) ([]byte, error) { + b, err, _ := d.requestReturnErrResp(uri, method, callback, retry...) 
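+ // Callers that need to branch on a specific API error code use requestReturnErrResp directly rather than this wrapper. An illustrative sketch mirroring the pre-hash check in upload (createData/createResp are the names used there): + // _, err, e := d.requestReturnErrResp("/adrive/v1.0/openFile/create", http.MethodPost, func(req *resty.Request) { + // req.SetBody(createData).SetResult(&createResp) + // }) + // if e != nil && e.Code == "PreHashMatched" { + // // fall back to the rapid-upload path instead of failing + // }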
+ return b, err +} + +func (d *AliyundriveOpen) requestReturnErrResp(uri, method string, callback base.ReqCallback, retry ...bool) ([]byte, error, *ErrResp) { + req := base.RestyClient.R() + // TODO check whether access_token is expired + req.SetHeader("Authorization", "Bearer "+d.getAccessToken()) + if method == http.MethodPost { + req.SetHeader("Content-Type", "application/json") + } + if callback != nil { + callback(req) + } + var e ErrResp + req.SetError(&e) + res, err := req.Execute(method, API_URL+uri) + if err != nil { + if res != nil { + log.Errorf("[aliyundrive_open] request error: %s", res.String()) + } + return nil, err, nil + } + isRetry := len(retry) > 0 && retry[0] + if e.Code != "" { + if !isRetry && (utils.SliceContains([]string{"AccessTokenInvalid", "AccessTokenExpired", "I400JD"}, e.Code) || d.getAccessToken() == "") { + err = d.refreshToken() + if err != nil { + return nil, err, nil + } + return d.requestReturnErrResp(uri, method, callback, true) + } + return nil, fmt.Errorf("%s:%s", e.Code, e.Message), &e + } + return res.Body(), nil, nil +} + +func (d *AliyundriveOpen) list(ctx context.Context, data base.Json) (*Files, error) { + var resp Files + _, err := d.request("/adrive/v1.0/openFile/list", http.MethodPost, func(req *resty.Request) { + req.SetBody(data).SetResult(&resp) + }) + if err != nil { + return nil, err + } + return &resp, nil +} + +func (d *AliyundriveOpen) getFiles(ctx context.Context, fileId string) ([]File, error) { + marker := "first" + res := make([]File, 0) + for marker != "" { + if marker == "first" { + marker = "" + } + data := base.Json{ + "drive_id": d.DriveId, + "limit": 200, + "marker": marker, + "order_by": d.OrderBy, + "order_direction": d.OrderDirection, + "parent_file_id": fileId, + //"category": "", + //"type": "", + //"video_thumbnail_time": 120000, + //"video_thumbnail_width": 480, + //"image_thumbnail_width": 480, + } + resp, err := d.limitList(ctx, data) + if err != nil { + return nil, err + } + marker = resp.NextMarker + res = append(res, resp.Items...) 
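+ // Marker pagination: the "first" sentinel only primes the loop (the first request is sent with an empty marker), and the server returns an empty NextMarker on the last page. A minimal sketch of the same pattern, with fetchPage as a hypothetical one-page call: + // marker := "first" + // for marker != "" { + // if marker == "first" { + // marker = "" + // } + // page, err := fetchPage(marker) // hypothetical helper + // if err != nil { + // break + // } + // marker = page.NextMarker + // }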
+ } + return res, nil +} + +func getNowTime() (time.Time, string) { + nowTime := time.Now() + nowTimeStr := nowTime.Format("2006-01-02T15:04:05.000Z") + return nowTime, nowTimeStr +} + +func (d *AliyundriveOpen) getAccessToken() string { + if d.ref != nil { + return d.ref.getAccessToken() + } + return d.AccessToken +} diff --git a/drivers/aliyundrive_share/driver.go b/drivers/aliyundrive_share/driver.go new file mode 100644 index 0000000000000000000000000000000000000000..db2ff9ace16e5e5ed12e8bc17f4d9adba7006a4f --- /dev/null +++ b/drivers/aliyundrive_share/driver.go @@ -0,0 +1,147 @@ +package aliyundrive_share + +import ( + "context" + "fmt" + "net/http" + "time" + + "github.com/Xhofe/rateg" + "github.com/alist-org/alist/v3/drivers/base" + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/errs" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/pkg/cron" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/go-resty/resty/v2" + log "github.com/sirupsen/logrus" +) + +type AliyundriveShare struct { + model.Storage + Addition + AccessToken string + ShareToken string + DriveId string + cron *cron.Cron + + limitList func(ctx context.Context, dir model.Obj) ([]model.Obj, error) + limitLink func(ctx context.Context, file model.Obj) (*model.Link, error) +} + +func (d *AliyundriveShare) Config() driver.Config { + return config +} + +func (d *AliyundriveShare) GetAddition() driver.Additional { + return &d.Addition +} + +func (d *AliyundriveShare) Init(ctx context.Context) error { + err := d.refreshToken() + if err != nil { + return err + } + err = d.getShareToken() + if err != nil { + return err + } + d.cron = cron.NewCron(time.Hour * 2) + d.cron.Do(func() { + err := d.refreshToken() + if err != nil { + log.Errorf("%+v", err) + } + }) + d.limitList = rateg.LimitFnCtx(d.list, rateg.LimitFnOption{ + Limit: 4, + Bucket: 1, + }) + d.limitLink = rateg.LimitFnCtx(d.link, rateg.LimitFnOption{ + Limit: 1, + Bucket: 1, + }) + return nil +} + +func (d *AliyundriveShare) Drop(ctx context.Context) error { + if d.cron != nil { + d.cron.Stop() + } + d.DriveId = "" + return nil +} + +func (d *AliyundriveShare) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) { + if d.limitList == nil { + return nil, fmt.Errorf("driver not init") + } + return d.limitList(ctx, dir) +} + +func (d *AliyundriveShare) list(ctx context.Context, dir model.Obj) ([]model.Obj, error) { + files, err := d.getFiles(dir.GetID()) + if err != nil { + return nil, err + } + return utils.SliceConvert(files, func(src File) (model.Obj, error) { + return fileToObj(src), nil + }) +} + +func (d *AliyundriveShare) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) { + if d.limitLink == nil { + return nil, fmt.Errorf("driver not init") + } + return d.limitLink(ctx, file) +} + +func (d *AliyundriveShare) link(ctx context.Context, file model.Obj) (*model.Link, error) { + data := base.Json{ + "drive_id": d.DriveId, + "file_id": file.GetID(), + // // Only ten minutes lifetime + "expire_sec": 600, + "share_id": d.ShareId, + } + var resp ShareLinkResp + _, err := d.request("https://api.alipan.com/v2/file/get_share_link_download_url", http.MethodPost, func(req *resty.Request) { + req.SetHeader(CanaryHeaderKey, CanaryHeaderValue).SetBody(data).SetResult(&resp) + }) + if err != nil { + return nil, err + } + return &model.Link{ + Header: http.Header{ + "Referer": []string{"https://www.alipan.com/"}, + }, + URL: 
resp.DownloadUrl, + }, nil +} + +func (d *AliyundriveShare) Other(ctx context.Context, args model.OtherArgs) (interface{}, error) { + var resp base.Json + var url string + data := base.Json{ + "share_id": d.ShareId, + "file_id": args.Obj.GetID(), + } + switch args.Method { + case "doc_preview": + url = "https://api.alipan.com/v2/file/get_office_preview_url" + case "video_preview": + url = "https://api.alipan.com/v2/file/get_video_preview_play_info" + data["category"] = "live_transcoding" + default: + return nil, errs.NotSupport + } + _, err := d.request(url, http.MethodPost, func(req *resty.Request) { + req.SetBody(data).SetResult(&resp) + }) + if err != nil { + return nil, err + } + return resp, nil +} + +var _ driver.Driver = (*AliyundriveShare)(nil) diff --git a/drivers/aliyundrive_share/meta.go b/drivers/aliyundrive_share/meta.go new file mode 100644 index 0000000000000000000000000000000000000000..acbd6c9bd3aa53d31080c96942b9bf74def5fb56 --- /dev/null +++ b/drivers/aliyundrive_share/meta.go @@ -0,0 +1,29 @@ +package aliyundrive_share + +import ( + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/op" +) + +type Addition struct { + RefreshToken string `json:"refresh_token" required:"true"` + ShareId string `json:"share_id" required:"true"` + SharePwd string `json:"share_pwd"` + driver.RootID + OrderBy string `json:"order_by" type:"select" options:"name,size,updated_at,created_at"` + OrderDirection string `json:"order_direction" type:"select" options:"ASC,DESC"` +} + +var config = driver.Config{ + Name: "AliyundriveShare", + LocalSort: false, + OnlyProxy: false, + NoUpload: true, + DefaultRoot: "root", +} + +func init() { + op.RegisterDriver(func() driver.Driver { + return &AliyundriveShare{} + }) +} diff --git a/drivers/aliyundrive_share/types.go b/drivers/aliyundrive_share/types.go new file mode 100644 index 0000000000000000000000000000000000000000..bb9be800e51e1b25dc1602ac13122721a0eae11d --- /dev/null +++ b/drivers/aliyundrive_share/types.go @@ -0,0 +1,58 @@ +package aliyundrive_share + +import ( + "time" + + "github.com/alist-org/alist/v3/internal/model" +) + +type ErrorResp struct { + Code string `json:"code"` + Message string `json:"message"` +} + +type ShareTokenResp struct { + ShareToken string `json:"share_token"` + ExpireTime time.Time `json:"expire_time"` + ExpiresIn int `json:"expires_in"` +} + +type ListResp struct { + Items []File `json:"items"` + NextMarker string `json:"next_marker"` + PunishedFileCount int `json:"punished_file_count"` +} + +type File struct { + DriveId string `json:"drive_id"` + DomainId string `json:"domain_id"` + FileId string `json:"file_id"` + ShareId string `json:"share_id"` + Name string `json:"name"` + Type string `json:"type"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` + ParentFileId string `json:"parent_file_id"` + Size int64 `json:"size"` + Thumbnail string `json:"thumbnail"` +} + +func fileToObj(f File) *model.ObjThumb { + return &model.ObjThumb{ + Object: model.Object{ + ID: f.FileId, + Name: f.Name, + Size: f.Size, + Modified: f.UpdatedAt, + Ctime: f.CreatedAt, + IsFolder: f.Type == "folder", + }, + Thumbnail: model.Thumbnail{Thumbnail: f.Thumbnail}, + } +} + +type ShareLinkResp struct { + DownloadUrl string `json:"download_url"` + Url string `json:"url"` + Thumbnail string `json:"thumbnail"` +} diff --git a/drivers/aliyundrive_share/util.go b/drivers/aliyundrive_share/util.go new file mode 100644 index 
0000000000000000000000000000000000000000..899e15cec1b113b38b18b7b52b97b64606efee32 --- /dev/null +++ b/drivers/aliyundrive_share/util.go @@ -0,0 +1,141 @@ +package aliyundrive_share + +import ( + "errors" + "fmt" + + "github.com/alist-org/alist/v3/drivers/base" + "github.com/alist-org/alist/v3/internal/op" + log "github.com/sirupsen/logrus" +) + +const ( + // CanaryHeaderKey CanaryHeaderValue for lifting rate limit restrictions + CanaryHeaderKey = "X-Canary" + CanaryHeaderValue = "client=web,app=share,version=v2.3.1" +) + +func (d *AliyundriveShare) refreshToken() error { + url := "https://auth.alipan.com/v2/account/token" + var resp base.TokenResp + var e ErrorResp + _, err := base.RestyClient.R(). + SetBody(base.Json{"refresh_token": d.RefreshToken, "grant_type": "refresh_token"}). + SetResult(&resp). + SetError(&e). + Post(url) + if err != nil { + return err + } + if e.Code != "" { + return fmt.Errorf("failed to refresh token: %s", e.Message) + } + d.RefreshToken, d.AccessToken = resp.RefreshToken, resp.AccessToken + op.MustSaveDriverStorage(d) + return nil +} + +// do others that not defined in Driver interface +func (d *AliyundriveShare) getShareToken() error { + data := base.Json{ + "share_id": d.ShareId, + } + if d.SharePwd != "" { + data["share_pwd"] = d.SharePwd + } + var e ErrorResp + var resp ShareTokenResp + _, err := base.RestyClient.R(). + SetResult(&resp).SetError(&e).SetBody(data). + Post("https://api.alipan.com/v2/share_link/get_share_token") + if err != nil { + return err + } + if e.Code != "" { + return errors.New(e.Message) + } + d.ShareToken = resp.ShareToken + return nil +} + +func (d *AliyundriveShare) request(url, method string, callback base.ReqCallback) ([]byte, error) { + var e ErrorResp + req := base.RestyClient.R(). + SetError(&e). + SetHeader("content-type", "application/json"). + SetHeader("Authorization", "Bearer\t"+d.AccessToken). + SetHeader(CanaryHeaderKey, CanaryHeaderValue). + SetHeader("x-share-token", d.ShareToken) + if callback != nil { + callback(req) + } else { + req.SetBody("{}") + } + resp, err := req.Execute(method, url) + if err != nil { + return nil, err + } + if e.Code != "" { + if e.Code == "AccessTokenInvalid" || e.Code == "ShareLinkTokenInvalid" { + if e.Code == "AccessTokenInvalid" { + err = d.refreshToken() + } else { + err = d.getShareToken() + } + if err != nil { + return nil, err + } + return d.request(url, method, callback) + } else { + return nil, errors.New(e.Code + ": " + e.Message) + } + } + return resp.Body(), nil +} + +func (d *AliyundriveShare) getFiles(fileId string) ([]File, error) { + files := make([]File, 0) + data := base.Json{ + "image_thumbnail_process": "image/resize,w_160/format,jpeg", + "image_url_process": "image/resize,w_1920/format,jpeg", + "limit": 200, + "order_by": d.OrderBy, + "order_direction": d.OrderDirection, + "parent_file_id": fileId, + "share_id": d.ShareId, + "video_thumbnail_process": "video/snapshot,t_1000,f_jpg,ar_auto,w_300", + "marker": "first", + } + for data["marker"] != "" { + if data["marker"] == "first" { + data["marker"] = "" + } + var e ErrorResp + var resp ListResp + res, err := base.RestyClient.R(). + SetHeader("x-share-token", d.ShareToken). + SetHeader(CanaryHeaderKey, CanaryHeaderValue). + SetResult(&resp).SetError(&e).SetBody(data). 
+ Post("https://api.alipan.com/adrive/v3/file/list") + if err != nil { + return nil, err + } + log.Debugf("aliyundrive share get files: %s", res.String()) + if e.Code != "" { + if e.Code == "AccessTokenInvalid" || e.Code == "ShareLinkTokenInvalid" { + err = d.getShareToken() + if err != nil { + return nil, err + } + return d.getFiles(fileId) + } + return nil, errors.New(e.Message) + } + data["marker"] = resp.NextMarker + files = append(files, resp.Items...) + } + if len(files) > 0 && d.DriveId == "" { + d.DriveId = files[0].DriveId + } + return files, nil +} diff --git a/drivers/all.go b/drivers/all.go new file mode 100644 index 0000000000000000000000000000000000000000..d9997f15680acd97c286f5923a48d66b7d44abf3 --- /dev/null +++ b/drivers/all.go @@ -0,0 +1,73 @@ +package drivers + +import ( + _ "github.com/alist-org/alist/v3/drivers/115" + _ "github.com/alist-org/alist/v3/drivers/115_share" + _ "github.com/alist-org/alist/v3/drivers/123" + _ "github.com/alist-org/alist/v3/drivers/123_link" + _ "github.com/alist-org/alist/v3/drivers/123_share" + _ "github.com/alist-org/alist/v3/drivers/139" + _ "github.com/alist-org/alist/v3/drivers/189" + _ "github.com/alist-org/alist/v3/drivers/189pc" + _ "github.com/alist-org/alist/v3/drivers/alias" + _ "github.com/alist-org/alist/v3/drivers/alist_v2" + _ "github.com/alist-org/alist/v3/drivers/alist_v3" + _ "github.com/alist-org/alist/v3/drivers/aliyundrive" + _ "github.com/alist-org/alist/v3/drivers/aliyundrive_open" + _ "github.com/alist-org/alist/v3/drivers/aliyundrive_share" + _ "github.com/alist-org/alist/v3/drivers/baidu_netdisk" + _ "github.com/alist-org/alist/v3/drivers/baidu_photo" + _ "github.com/alist-org/alist/v3/drivers/baidu_share" + _ "github.com/alist-org/alist/v3/drivers/chaoxing" + _ "github.com/alist-org/alist/v3/drivers/cloudreve" + _ "github.com/alist-org/alist/v3/drivers/crypt" + _ "github.com/alist-org/alist/v3/drivers/dropbox" + _ "github.com/alist-org/alist/v3/drivers/febbox" + _ "github.com/alist-org/alist/v3/drivers/ftp" + _ "github.com/alist-org/alist/v3/drivers/google_drive" + _ "github.com/alist-org/alist/v3/drivers/google_photo" + _ "github.com/alist-org/alist/v3/drivers/halalcloud" + _ "github.com/alist-org/alist/v3/drivers/ilanzou" + _ "github.com/alist-org/alist/v3/drivers/ipfs_api" + _ "github.com/alist-org/alist/v3/drivers/kodbox" + _ "github.com/alist-org/alist/v3/drivers/lanzou" + _ "github.com/alist-org/alist/v3/drivers/lenovonas_share" + _ "github.com/alist-org/alist/v3/drivers/local" + _ "github.com/alist-org/alist/v3/drivers/mediatrack" + _ "github.com/alist-org/alist/v3/drivers/mega" + _ "github.com/alist-org/alist/v3/drivers/mopan" + _ "github.com/alist-org/alist/v3/drivers/netease_music" + _ "github.com/alist-org/alist/v3/drivers/onedrive" + _ "github.com/alist-org/alist/v3/drivers/onedrive_app" + _ "github.com/alist-org/alist/v3/drivers/onedrive_sharelink" + _ "github.com/alist-org/alist/v3/drivers/pikpak" + _ "github.com/alist-org/alist/v3/drivers/pikpak_proxy" + _ "github.com/alist-org/alist/v3/drivers/pikpak_share" + _ "github.com/alist-org/alist/v3/drivers/quark_uc" + _ "github.com/alist-org/alist/v3/drivers/quark_uc_tv" + _ "github.com/alist-org/alist/v3/drivers/quqi" + _ "github.com/alist-org/alist/v3/drivers/s3" + _ "github.com/alist-org/alist/v3/drivers/seafile" + _ "github.com/alist-org/alist/v3/drivers/sftp" + _ "github.com/alist-org/alist/v3/drivers/smb" + _ "github.com/alist-org/alist/v3/drivers/teambition" + _ "github.com/alist-org/alist/v3/drivers/terabox" + _ 
"github.com/alist-org/alist/v3/drivers/thunder" + _ "github.com/alist-org/alist/v3/drivers/thunder_browser" + _ "github.com/alist-org/alist/v3/drivers/thunderx" + _ "github.com/alist-org/alist/v3/drivers/trainbit" + _ "github.com/alist-org/alist/v3/drivers/url_tree" + _ "github.com/alist-org/alist/v3/drivers/uss" + _ "github.com/alist-org/alist/v3/drivers/virtual" + _ "github.com/alist-org/alist/v3/drivers/vtencent" + _ "github.com/alist-org/alist/v3/drivers/webdav" + _ "github.com/alist-org/alist/v3/drivers/weiyun" + _ "github.com/alist-org/alist/v3/drivers/wopan" + _ "github.com/alist-org/alist/v3/drivers/yandex_disk" +) + +// All do nothing,just for import +// same as _ import +func All() { + +} diff --git a/drivers/baidu_netdisk/driver.go b/drivers/baidu_netdisk/driver.go new file mode 100644 index 0000000000000000000000000000000000000000..ad52a4b54384fad59720a0668ee4d53a1ded3c69 --- /dev/null +++ b/drivers/baidu_netdisk/driver.go @@ -0,0 +1,332 @@ +package baidu_netdisk + +import ( + "context" + "crypto/md5" + "encoding/hex" + "errors" + "io" + "math" + "net/url" + stdpath "path" + "strconv" + "time" + + "github.com/alist-org/alist/v3/drivers/base" + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/errs" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/pkg/errgroup" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/avast/retry-go" + log "github.com/sirupsen/logrus" +) + +type BaiduNetdisk struct { + model.Storage + Addition + + uploadThread int + vipType int // 会员类型,0普通用户(4G/4M)、1普通会员(10G/16M)、2超级会员(20G/32M) +} + +func (d *BaiduNetdisk) Config() driver.Config { + return config +} + +func (d *BaiduNetdisk) GetAddition() driver.Additional { + return &d.Addition +} + +func (d *BaiduNetdisk) Init(ctx context.Context) error { + d.uploadThread, _ = strconv.Atoi(d.UploadThread) + if d.uploadThread < 1 || d.uploadThread > 32 { + d.uploadThread, d.UploadThread = 3, "3" + } + + if _, err := url.Parse(d.UploadAPI); d.UploadAPI == "" || err != nil { + d.UploadAPI = "https://d.pcs.baidu.com" + } + + res, err := d.get("/xpan/nas", map[string]string{ + "method": "uinfo", + }, nil) + log.Debugf("[baidu] get uinfo: %s", string(res)) + if err != nil { + return err + } + d.vipType = utils.Json.Get(res, "vip_type").ToInt() + return nil +} + +func (d *BaiduNetdisk) Drop(ctx context.Context) error { + return nil +} + +func (d *BaiduNetdisk) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) { + files, err := d.getFiles(dir.GetPath()) + if err != nil { + return nil, err + } + return utils.SliceConvert(files, func(src File) (model.Obj, error) { + return fileToObj(src), nil + }) +} + +func (d *BaiduNetdisk) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) { + if d.DownloadAPI == "crack" { + return d.linkCrack(file, args) + } + return d.linkOfficial(file, args) +} + +func (d *BaiduNetdisk) MakeDir(ctx context.Context, parentDir model.Obj, dirName string) (model.Obj, error) { + var newDir File + _, err := d.create(stdpath.Join(parentDir.GetPath(), dirName), 0, 1, "", "", &newDir, 0, 0) + if err != nil { + return nil, err + } + return fileToObj(newDir), nil +} + +func (d *BaiduNetdisk) Move(ctx context.Context, srcObj, dstDir model.Obj) (model.Obj, error) { + data := []base.Json{ + { + "path": srcObj.GetPath(), + "dest": dstDir.GetPath(), + "newname": srcObj.GetName(), + }, + } + _, err := d.manage("move", data) + if err != nil { + return nil, err + } + if srcObj, 
ok := srcObj.(*model.ObjThumb); ok { + srcObj.SetPath(stdpath.Join(dstDir.GetPath(), srcObj.GetName())) + srcObj.Modified = time.Now() + return srcObj, nil + } + return nil, nil +} + +func (d *BaiduNetdisk) Rename(ctx context.Context, srcObj model.Obj, newName string) (model.Obj, error) { + data := []base.Json{ + { + "path": srcObj.GetPath(), + "newname": newName, + }, + } + _, err := d.manage("rename", data) + if err != nil { + return nil, err + } + + if srcObj, ok := srcObj.(*model.ObjThumb); ok { + srcObj.SetPath(stdpath.Join(stdpath.Dir(srcObj.GetPath()), newName)) + srcObj.Name = newName + srcObj.Modified = time.Now() + return srcObj, nil + } + return nil, nil +} + +func (d *BaiduNetdisk) Copy(ctx context.Context, srcObj, dstDir model.Obj) error { + data := []base.Json{ + { + "path": srcObj.GetPath(), + "dest": dstDir.GetPath(), + "newname": srcObj.GetName(), + }, + } + _, err := d.manage("copy", data) + return err +} + +func (d *BaiduNetdisk) Remove(ctx context.Context, obj model.Obj) error { + data := []string{obj.GetPath()} + _, err := d.manage("delete", data) + return err +} + +func (d *BaiduNetdisk) PutRapid(ctx context.Context, dstDir model.Obj, stream model.FileStreamer) (model.Obj, error) { + contentMd5 := stream.GetHash().GetHash(utils.MD5) + if len(contentMd5) < utils.MD5.Width { + return nil, errors.New("invalid hash") + } + + streamSize := stream.GetSize() + path := stdpath.Join(dstDir.GetPath(), stream.GetName()) + mtime := stream.ModTime().Unix() + ctime := stream.CreateTime().Unix() + blockList, _ := utils.Json.MarshalToString([]string{contentMd5}) + + var newFile File + _, err := d.create(path, streamSize, 0, "", blockList, &newFile, mtime, ctime) + if err != nil { + return nil, err + } + // fix the timestamps; see the **Note** on Put for why + newFile.Ctime = stream.CreateTime().Unix() + newFile.Mtime = stream.ModTime().Unix() + return fileToObj(newFile), nil +} + +// Put +// +// **Note**: as of 2024/04/20 the Baidu Netdisk API always returns the current time rather than the file's own time, even though the time actually stored in the cloud is the file time. We therefore overwrite the returned timestamps so the cache stays consistent with the cloud. +func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) (model.Obj, error) { + // rapid upload + if newObj, err := d.PutRapid(ctx, dstDir, stream); err == nil { + return newObj, nil + } + + tempFile, err := stream.CacheFullInTempFile() + if err != nil { + return nil, err + } + + streamSize := stream.GetSize() + sliceSize := d.getSliceSize() + count := int(math.Max(math.Ceil(float64(streamSize)/float64(sliceSize)), 1)) + lastBlockSize := streamSize % sliceSize + if streamSize > 0 && lastBlockSize == 0 { + lastBlockSize = sliceSize + } + + // calculate the md5 of the first 256 KB of data + const SliceSize int64 = 256 * 1024 + // calculate the block md5s + blockList := make([]string, 0, count) + byteSize := sliceSize + fileMd5H := md5.New() + sliceMd5H := md5.New() + sliceMd5H2 := md5.New() + slicemd5H2Write := utils.LimitWriter(sliceMd5H2, SliceSize) + + for i := 1; i <= count; i++ { + if utils.IsCanceled(ctx) { + return nil, ctx.Err() + } + if i == count { + byteSize = lastBlockSize + } + _, err := utils.CopyWithBufferN(io.MultiWriter(fileMd5H, sliceMd5H, slicemd5H2Write), tempFile, byteSize) + if err != nil && err != io.EOF { + return nil, err + } + blockList = append(blockList, hex.EncodeToString(sliceMd5H.Sum(nil))) + sliceMd5H.Reset() + } + contentMd5 := hex.EncodeToString(fileMd5H.Sum(nil)) + sliceMd5 := hex.EncodeToString(sliceMd5H2.Sum(nil)) + blockListStr, _ := utils.Json.MarshalToString(blockList) + path := stdpath.Join(dstDir.GetPath(), stream.GetName()) + mtime :=
stream.ModTime().Unix() + ctime := stream.CreateTime().Unix() + + // step 1: precreate (pre-upload) + // try to resume from previously saved progress + precreateResp, ok := base.GetUploadProgress[*PrecreateResp](d, d.AccessToken, contentMd5) + if !ok { + params := map[string]string{ + "method": "precreate", + } + form := map[string]string{ + "path": path, + "size": strconv.FormatInt(streamSize, 10), + "isdir": "0", + "autoinit": "1", + "rtype": "3", + "block_list": blockListStr, + "content-md5": contentMd5, + "slice-md5": sliceMd5, + } + joinTime(form, ctime, mtime) + + log.Debugf("[baidu_netdisk] precreate data: %s", form) + _, err = d.postForm("/xpan/file", params, form, &precreateResp) + if err != nil { + return nil, err + } + log.Debugf("%+v", precreateResp) + if precreateResp.ReturnType == 2 { + // rapid upload: the baidu server reported an md5 match + // fix the timestamps; see the **Note** on Put for why + precreateResp.File.Ctime = ctime + precreateResp.File.Mtime = mtime + return fileToObj(precreateResp.File), nil + } + } + // step 2: upload the slices + threadG, upCtx := errgroup.NewGroupWithContext(ctx, d.uploadThread, + retry.Attempts(3), + retry.Delay(time.Second), + retry.DelayType(retry.BackOffDelay)) + for i, partseq := range precreateResp.BlockList { + if utils.IsCanceled(upCtx) { + break + } + + i, partseq, offset, byteSize := i, partseq, int64(partseq)*sliceSize, sliceSize + if partseq+1 == count { + byteSize = lastBlockSize + } + threadG.Go(func(ctx context.Context) error { + params := map[string]string{ + "method": "upload", + "access_token": d.AccessToken, + "type": "tmpfile", + "path": path, + "uploadid": precreateResp.Uploadid, + "partseq": strconv.Itoa(partseq), + } + err := d.uploadSlice(ctx, params, stream.GetName(), io.NewSectionReader(tempFile, offset, byteSize)) + if err != nil { + return err + } + up(float64(threadG.Success()) * 100 / float64(len(precreateResp.BlockList))) + precreateResp.BlockList[i] = -1 + return nil + }) + } + if err = threadG.Wait(); err != nil { + // if the user canceled the upload, save the progress so it can resume + if errors.Is(err, context.Canceled) { + precreateResp.BlockList = utils.SliceFilter(precreateResp.BlockList, func(s int) bool { return s >= 0 }) + base.SaveUploadProgress(d, precreateResp, d.AccessToken, contentMd5) + } + return nil, err + } + + // step 3: create the file + var newFile File + _, err = d.create(path, streamSize, 0, precreateResp.Uploadid, blockListStr, &newFile, mtime, ctime) + if err != nil { + return nil, err + } + // fix the timestamps; see the **Note** on Put for why + newFile.Ctime = ctime + newFile.Mtime = mtime + return fileToObj(newFile), nil +} + +func (d *BaiduNetdisk) uploadSlice(ctx context.Context, params map[string]string, fileName string, file io.Reader) error { + res, err := base.RestyClient.R(). + SetContext(ctx). + SetQueryParams(params). + SetFileReader("file", fileName, file). + Post(d.UploadAPI + "/rest/2.0/pcs/superfile2") + if err != nil { + return err + } + log.Debugln(res.RawResponse.Status + res.String()) + errCode := utils.Json.Get(res.Body(), "error_code").ToInt() + errNo := utils.Json.Get(res.Body(), "errno").ToInt() + if errCode != 0 || errNo != 0 { + return errs.NewErr(errs.StreamIncomplete, "error in uploading to baidu, will retry.
response=%s", res.String()) + } + return nil +} + +var _ driver.Driver = (*BaiduNetdisk)(nil) diff --git a/drivers/baidu_netdisk/meta.go b/drivers/baidu_netdisk/meta.go new file mode 100644 index 0000000000000000000000000000000000000000..bf2aed5a2b5a9c816717ed62f333ac9cf0da3c9f --- /dev/null +++ b/drivers/baidu_netdisk/meta.go @@ -0,0 +1,32 @@ +package baidu_netdisk + +import ( + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/op" +) + +type Addition struct { + RefreshToken string `json:"refresh_token" required:"true"` + driver.RootPath + OrderBy string `json:"order_by" type:"select" options:"name,time,size" default:"name"` + OrderDirection string `json:"order_direction" type:"select" options:"asc,desc" default:"asc"` + DownloadAPI string `json:"download_api" type:"select" options:"official,crack" default:"official"` + ClientID string `json:"client_id" required:"true" default:"iYCeC9g08h5vuP9UqvPHKKSVrKFXGa1v"` + ClientSecret string `json:"client_secret" required:"true" default:"jXiFMOPVPCWlO2M5CwWQzffpNPaGTRBG"` + CustomCrackUA string `json:"custom_crack_ua" required:"true" default:"netdisk"` + AccessToken string + UploadThread string `json:"upload_thread" default:"3" help:"1<=thread<=32"` + UploadAPI string `json:"upload_api" default:"https://d.pcs.baidu.com"` + CustomUploadPartSize int64 `json:"custom_upload_part_size" type:"number" default:"0" help:"0 for auto"` +} + +var config = driver.Config{ + Name: "BaiduNetdisk", + DefaultRoot: "/", +} + +func init() { + op.RegisterDriver(func() driver.Driver { + return &BaiduNetdisk{} + }) +} diff --git a/drivers/baidu_netdisk/types.go b/drivers/baidu_netdisk/types.go new file mode 100644 index 0000000000000000000000000000000000000000..728273b8dab3a30646f32fad98a5d7dac5162386 --- /dev/null +++ b/drivers/baidu_netdisk/types.go @@ -0,0 +1,191 @@ +package baidu_netdisk + +import ( + "path" + "strconv" + "time" + + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/pkg/utils" +) + +type TokenErrResp struct { + ErrorDescription string `json:"error_description"` + Error string `json:"error"` +} + +type File struct { + //TkbindId int `json:"tkbind_id"` + //OwnerType int `json:"owner_type"` + //Category int `json:"category"` + //RealCategory string `json:"real_category"` + FsId int64 `json:"fs_id"` + //OperId int `json:"oper_id"` + Thumbs struct { + //Icon string `json:"icon"` + Url3 string `json:"url3"` + //Url2 string `json:"url2"` + //Url1 string `json:"url1"` + } `json:"thumbs"` + //Wpfile int `json:"wpfile"` + + Size int64 `json:"size"` + //ExtentTinyint7 int `json:"extent_tinyint7"` + Path string `json:"path"` + //Share int `json:"share"` + //Pl int `json:"pl"` + ServerFilename string `json:"server_filename"` + Md5 string `json:"md5"` + //OwnerId int `json:"owner_id"` + //Unlist int `json:"unlist"` + Isdir int `json:"isdir"` + + // list resp + ServerCtime int64 `json:"server_ctime"` + ServerMtime int64 `json:"server_mtime"` + LocalMtime int64 `json:"local_mtime"` + LocalCtime int64 `json:"local_ctime"` + //ServerAtime int64 `json:"server_atime"` ` + + // only create and precreate resp + Ctime int64 `json:"ctime"` + Mtime int64 `json:"mtime"` +} + +func fileToObj(f File) *model.ObjThumb { + if f.ServerFilename == "" { + f.ServerFilename = path.Base(f.Path) + } + if f.ServerCtime == 0 { + f.ServerCtime = f.Ctime + } + if f.ServerMtime == 0 { + f.ServerMtime = f.Mtime + } + return &model.ObjThumb{ + Object: model.Object{ + ID: strconv.FormatInt(f.FsId, 10), + Path: f.Path, + Name: 
f.ServerFilename, + Size: f.Size, + Modified: time.Unix(f.ServerMtime, 0), + Ctime: time.Unix(f.ServerCtime, 0), + IsFolder: f.Isdir == 1, + + // 直接获取的MD5是错误的 + HashInfo: utils.NewHashInfo(utils.MD5, DecryptMd5(f.Md5)), + }, + Thumbnail: model.Thumbnail{Thumbnail: f.Thumbs.Url3}, + } +} + +type ListResp struct { + Errno int `json:"errno"` + GuidInfo string `json:"guid_info"` + List []File `json:"list"` + RequestId int64 `json:"request_id"` + Guid int `json:"guid"` +} + +type DownloadResp struct { + Errmsg string `json:"errmsg"` + Errno int `json:"errno"` + List []struct { + //Category int `json:"category"` + //DateTaken int `json:"date_taken,omitempty"` + Dlink string `json:"dlink"` + //Filename string `json:"filename"` + //FsId int64 `json:"fs_id"` + //Height int `json:"height,omitempty"` + //Isdir int `json:"isdir"` + //Md5 string `json:"md5"` + //OperId int `json:"oper_id"` + //Path string `json:"path"` + //ServerCtime int `json:"server_ctime"` + //ServerMtime int `json:"server_mtime"` + //Size int `json:"size"` + //Thumbs struct { + // Icon string `json:"icon,omitempty"` + // Url1 string `json:"url1,omitempty"` + // Url2 string `json:"url2,omitempty"` + // Url3 string `json:"url3,omitempty"` + //} `json:"thumbs"` + //Width int `json:"width,omitempty"` + } `json:"list"` + //Names struct { + //} `json:"names"` + RequestId string `json:"request_id"` +} + +type DownloadResp2 struct { + Errno int `json:"errno"` + Info []struct { + //ExtentTinyint4 int `json:"extent_tinyint4"` + //ExtentTinyint1 int `json:"extent_tinyint1"` + //Bitmap string `json:"bitmap"` + //Category int `json:"category"` + //Isdir int `json:"isdir"` + //Videotag int `json:"videotag"` + Dlink string `json:"dlink"` + //OperID int64 `json:"oper_id"` + //PathMd5 int `json:"path_md5"` + //Wpfile int `json:"wpfile"` + //LocalMtime int `json:"local_mtime"` + /*Thumbs struct { + Icon string `json:"icon"` + URL3 string `json:"url3"` + URL2 string `json:"url2"` + URL1 string `json:"url1"` + } `json:"thumbs"`*/ + //PlaySource int `json:"play_source"` + //Share int `json:"share"` + //FileKey string `json:"file_key"` + //Errno int `json:"errno"` + //LocalCtime int `json:"local_ctime"` + //Rotate int `json:"rotate"` + //Metadata time.Time `json:"metadata"` + //Height int `json:"height"` + //SampleRate int `json:"sample_rate"` + //Width int `json:"width"` + //OwnerType int `json:"owner_type"` + //Privacy int `json:"privacy"` + //ExtentInt3 int64 `json:"extent_int3"` + //RealCategory string `json:"real_category"` + //SrcLocation string `json:"src_location"` + //MetaInfo string `json:"meta_info"` + //ID string `json:"id"` + //Duration int `json:"duration"` + //FileSize string `json:"file_size"` + //Channels int `json:"channels"` + //UseSegment int `json:"use_segment"` + //ServerCtime int `json:"server_ctime"` + //Resolution string `json:"resolution"` + //OwnerID int `json:"owner_id"` + //ExtraInfo string `json:"extra_info"` + //Size int `json:"size"` + //FsID int64 `json:"fs_id"` + //ExtentTinyint3 int `json:"extent_tinyint3"` + //Md5 string `json:"md5"` + //Path string `json:"path"` + //FrameRate int `json:"frame_rate"` + //ExtentTinyint2 int `json:"extent_tinyint2"` + //ServerFilename string `json:"server_filename"` + //ServerMtime int `json:"server_mtime"` + //TkbindID int `json:"tkbind_id"` + } `json:"info"` + RequestID int64 `json:"request_id"` +} + +type PrecreateResp struct { + Errno int `json:"errno"` + RequestId int64 `json:"request_id"` + ReturnType int `json:"return_type"` + + // return_type=1 + Path string `json:"path"` + 
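+ // Uploadid identifies the multipart session; together with the not-yet-uploaded BlockList it is the state that base.SaveUploadProgress persists on cancellation, so Put can resume. Illustrative sketch of the resume lookup as driver.go performs it: + // if resp, ok := base.GetUploadProgress[*PrecreateResp](d, d.AccessToken, contentMd5); ok { + // // upload only the slices still listed in resp.BlockList, + // // then call create with resp.Uploadid + // }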
Uploadid string `json:"uploadid"` + BlockList []int `json:"block_list"` + + // return_type=2 + File File `json:"info"` +} diff --git a/drivers/baidu_netdisk/util.go b/drivers/baidu_netdisk/util.go new file mode 100644 index 0000000000000000000000000000000000000000..ca1a6805a044307bcca9b2396786b05f730793ec --- /dev/null +++ b/drivers/baidu_netdisk/util.go @@ -0,0 +1,292 @@ +package baidu_netdisk + +import ( + "encoding/hex" + "errors" + "fmt" + "net/http" + "strconv" + "strings" + "time" + "unicode" + + "github.com/alist-org/alist/v3/drivers/base" + "github.com/alist-org/alist/v3/internal/errs" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/internal/op" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/avast/retry-go" + "github.com/go-resty/resty/v2" + log "github.com/sirupsen/logrus" +) + +// do others that not defined in Driver interface + +func (d *BaiduNetdisk) refreshToken() error { + err := d._refreshToken() + if err != nil && errors.Is(err, errs.EmptyToken) { + err = d._refreshToken() + } + return err +} + +func (d *BaiduNetdisk) _refreshToken() error { + u := "https://openapi.baidu.com/oauth/2.0/token" + var resp base.TokenResp + var e TokenErrResp + _, err := base.RestyClient.R().SetResult(&resp).SetError(&e).SetQueryParams(map[string]string{ + "grant_type": "refresh_token", + "refresh_token": d.RefreshToken, + "client_id": d.ClientID, + "client_secret": d.ClientSecret, + }).Get(u) + if err != nil { + return err + } + if e.Error != "" { + return fmt.Errorf("%s : %s", e.Error, e.ErrorDescription) + } + if resp.RefreshToken == "" { + return errs.EmptyToken + } + d.AccessToken, d.RefreshToken = resp.AccessToken, resp.RefreshToken + op.MustSaveDriverStorage(d) + return nil +} + +func (d *BaiduNetdisk) request(furl string, method string, callback base.ReqCallback, resp interface{}) ([]byte, error) { + var result []byte + err := retry.Do(func() error { + req := base.RestyClient.R() + req.SetQueryParam("access_token", d.AccessToken) + if callback != nil { + callback(req) + } + if resp != nil { + req.SetResult(resp) + } + res, err := req.Execute(method, furl) + if err != nil { + return err + } + log.Debugf("[baidu_netdisk] req: %s, resp: %s", furl, res.String()) + errno := utils.Json.Get(res.Body(), "errno").ToInt() + if errno != 0 { + if utils.SliceContains([]int{111, -6}, errno) { + log.Info("refreshing baidu_netdisk token.") + err2 := d.refreshToken() + if err2 != nil { + return retry.Unrecoverable(err2) + } + } + return fmt.Errorf("req: [%s] ,errno: %d, refer to https://pan.baidu.com/union/doc/", furl, errno) + } + result = res.Body() + return nil + }, + retry.LastErrorOnly(true), + retry.Attempts(3), + retry.Delay(time.Second), + retry.DelayType(retry.BackOffDelay)) + return result, err +} + +func (d *BaiduNetdisk) get(pathname string, params map[string]string, resp interface{}) ([]byte, error) { + return d.request("https://pan.baidu.com/rest/2.0"+pathname, http.MethodGet, func(req *resty.Request) { + req.SetQueryParams(params) + }, resp) +} + +func (d *BaiduNetdisk) postForm(pathname string, params map[string]string, form map[string]string, resp interface{}) ([]byte, error) { + return d.request("https://pan.baidu.com/rest/2.0"+pathname, http.MethodPost, func(req *resty.Request) { + req.SetQueryParams(params) + req.SetFormData(form) + }, resp) +} + +func (d *BaiduNetdisk) getFiles(dir string) ([]File, error) { + start := 0 + limit := 200 + params := map[string]string{ + "method": "list", + "dir": dir, + "web": "web", + } + if d.OrderBy != "" 
{ + params["order"] = d.OrderBy + if d.OrderDirection == "desc" { + params["desc"] = "1" + } + } + res := make([]File, 0) + for { + params["start"] = strconv.Itoa(start) + params["limit"] = strconv.Itoa(limit) + start += limit + var resp ListResp + _, err := d.get("/xpan/file", params, &resp) + if err != nil { + return nil, err + } + if len(resp.List) == 0 { + break + } + res = append(res, resp.List...) + } + return res, nil +} + +func (d *BaiduNetdisk) linkOfficial(file model.Obj, args model.LinkArgs) (*model.Link, error) { + var resp DownloadResp + params := map[string]string{ + "method": "filemetas", + "fsids": fmt.Sprintf("[%s]", file.GetID()), + "dlink": "1", + } + _, err := d.get("/xpan/multimedia", params, &resp) + if err != nil { + return nil, err + } + u := fmt.Sprintf("%s&access_token=%s", resp.List[0].Dlink, d.AccessToken) + res, err := base.NoRedirectClient.R().SetHeader("User-Agent", "pan.baidu.com").Head(u) + if err != nil { + return nil, err + } + //if res.StatusCode() == 302 { + u = res.Header().Get("location") + //} + + return &model.Link{ + URL: u, + Header: http.Header{ + "User-Agent": []string{"pan.baidu.com"}, + }, + }, nil +} + +func (d *BaiduNetdisk) linkCrack(file model.Obj, args model.LinkArgs) (*model.Link, error) { + var resp DownloadResp2 + param := map[string]string{ + "target": fmt.Sprintf("[\"%s\"]", file.GetPath()), + "dlink": "1", + "web": "5", + "origin": "dlna", + } + _, err := d.request("https://pan.baidu.com/api/filemetas", http.MethodGet, func(req *resty.Request) { + req.SetQueryParams(param) + }, &resp) + if err != nil { + return nil, err + } + + return &model.Link{ + URL: resp.Info[0].Dlink, + Header: http.Header{ + "User-Agent": []string{d.CustomCrackUA}, + }, + }, nil +} + +func (d *BaiduNetdisk) manage(opera string, filelist any) ([]byte, error) { + params := map[string]string{ + "method": "filemanager", + "opera": opera, + } + marshal, _ := utils.Json.MarshalToString(filelist) + return d.postForm("/xpan/file", params, map[string]string{ + "async": "0", + "filelist": marshal, + "ondup": "fail", + }, nil) +} + +func (d *BaiduNetdisk) create(path string, size int64, isdir int, uploadid, block_list string, resp any, mtime, ctime int64) ([]byte, error) { + params := map[string]string{ + "method": "create", + } + form := map[string]string{ + "path": path, + "size": strconv.FormatInt(size, 10), + "isdir": strconv.Itoa(isdir), + "rtype": "3", + } + if mtime != 0 && ctime != 0 { + joinTime(form, ctime, mtime) + } + + if uploadid != "" { + form["uploadid"] = uploadid + } + if block_list != "" { + form["block_list"] = block_list + } + return d.postForm("/xpan/file", params, form, resp) +} + +func joinTime(form map[string]string, ctime, mtime int64) { + form["local_mtime"] = strconv.FormatInt(mtime, 10) + form["local_ctime"] = strconv.FormatInt(ctime, 10) +} + +const ( + DefaultSliceSize int64 = 4 * utils.MB + VipSliceSize = 16 * utils.MB + SVipSliceSize = 32 * utils.MB +) + +func (d *BaiduNetdisk) getSliceSize() int64 { + if d.CustomUploadPartSize != 0 { + return d.CustomUploadPartSize + } + switch d.vipType { + case 1: + return VipSliceSize + case 2: + return SVipSliceSize + default: + return DefaultSliceSize + } +} + +// func encodeURIComponent(str string) string { +// r := url.QueryEscape(str) +// r = strings.ReplaceAll(r, "+", "%20") +// return r +// } + +func DecryptMd5(encryptMd5 string) string { + if _, err := hex.DecodeString(encryptMd5); err == nil { + return encryptMd5 + } + + var out strings.Builder + out.Grow(len(encryptMd5)) + for i, n := 0, 
int64(0); i < len(encryptMd5); i++ { + if i == 9 { + n = int64(unicode.ToLower(rune(encryptMd5[i])) - 'g') + } else { + n, _ = strconv.ParseInt(encryptMd5[i:i+1], 16, 64) + } + out.WriteString(strconv.FormatInt(n^int64(15&i), 16)) + } + + encryptMd5 = out.String() + return encryptMd5[8:16] + encryptMd5[:8] + encryptMd5[24:32] + encryptMd5[16:24] +} + +func EncryptMd5(originalMd5 string) string { + reversed := originalMd5[8:16] + originalMd5[:8] + originalMd5[24:32] + originalMd5[16:24] + + var out strings.Builder + out.Grow(len(reversed)) + for i, n := 0, int64(0); i < len(reversed); i++ { + n, _ = strconv.ParseInt(reversed[i:i+1], 16, 64) + n ^= int64(15 & i) + if i == 9 { + out.WriteRune(rune(n) + 'g') + } else { + out.WriteString(strconv.FormatInt(n, 16)) + } + } + return out.String() +} diff --git a/drivers/baidu_photo/driver.go b/drivers/baidu_photo/driver.go new file mode 100644 index 0000000000000000000000000000000000000000..d0d69e822222e925d93d14ad59e8b7079d2d7cc2 --- /dev/null +++ b/drivers/baidu_photo/driver.go @@ -0,0 +1,377 @@ +package baiduphoto + +import ( + "context" + "crypto/md5" + "encoding/hex" + "errors" + "fmt" + "io" + "math" + "regexp" + "strconv" + "strings" + "time" + + "github.com/alist-org/alist/v3/drivers/base" + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/errs" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/pkg/errgroup" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/avast/retry-go" + "github.com/go-resty/resty/v2" +) + +type BaiduPhoto struct { + model.Storage + Addition + + // AccessToken string + Uk int64 + root model.Obj + + uploadThread int +} + +func (d *BaiduPhoto) Config() driver.Config { + return config +} + +func (d *BaiduPhoto) GetAddition() driver.Additional { + return &d.Addition +} + +func (d *BaiduPhoto) Init(ctx context.Context) error { + d.uploadThread, _ = strconv.Atoi(d.UploadThread) + if d.uploadThread < 1 || d.uploadThread > 32 { + d.uploadThread, d.UploadThread = 3, "3" + } + + // if err := d.refreshToken(); err != nil { + // return err + // } + + // root + if d.AlbumID != "" { + albumID := strings.Split(d.AlbumID, "|")[0] + album, err := d.GetAlbumDetail(ctx, albumID) + if err != nil { + return err + } + d.root = album + } else { + d.root = &Root{ + Name: "root", + Modified: d.Modified, + IsFolder: true, + } + } + + // uk + info, err := d.uInfo() + if err != nil { + return err + } + d.Uk, err = strconv.ParseInt(info.YouaID, 10, 64) + return err +} + +func (d *BaiduPhoto) GetRoot(ctx context.Context) (model.Obj, error) { + return d.root, nil +} + +func (d *BaiduPhoto) Drop(ctx context.Context) error { + // d.AccessToken = "" + d.Uk = 0 + d.root = nil + return nil +} + +func (d *BaiduPhoto) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) { + var err error + + /* album */ + if album, ok := dir.(*Album); ok { + var files []AlbumFile + files, err = d.GetAllAlbumFile(ctx, album, "") + if err != nil { + return nil, err + } + + return utils.MustSliceConvert(files, func(file AlbumFile) model.Obj { + return &file + }), nil + } + + /* root */ + var albums []Album + if d.ShowType != "root_only_file" { + albums, err = d.GetAllAlbum(ctx) + if err != nil { + return nil, err + } + } + + var files []File + if d.ShowType != "root_only_album" { + files, err = d.GetAllFile(ctx) + if err != nil { + return nil, err + } + } + + return append( + utils.MustSliceConvert(albums, func(album Album) model.Obj { + return &album + }), + 
utils.MustSliceConvert(files, func(album File) model.Obj { + return &album + })..., + ), nil + +} + +func (d *BaiduPhoto) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) { + switch file := file.(type) { + case *File: + return d.linkFile(ctx, file, args) + case *AlbumFile: + // handle shared albums + if d.Uk != file.Uk { + // there is a chance the link cannot be fetched this way + // return d.linkAlbum(ctx, file, args) + + f, err := d.CopyAlbumFile(ctx, file) + if err != nil { + return nil, err + } + return d.linkFile(ctx, f, args) + } + return d.linkFile(ctx, &file.File, args) + } + return nil, errs.NotFile +} + +var joinReg = regexp.MustCompile(`(?i)join:([\S]*)`) + +func (d *BaiduPhoto) MakeDir(ctx context.Context, parentDir model.Obj, dirName string) (model.Obj, error) { + if _, ok := parentDir.(*Root); ok { + code := joinReg.FindStringSubmatch(dirName) + if len(code) > 1 { + return d.JoinAlbum(ctx, code[1]) + } + return d.CreateAlbum(ctx, dirName) + } + return nil, errs.NotSupport +} + +func (d *BaiduPhoto) Copy(ctx context.Context, srcObj, dstDir model.Obj) (model.Obj, error) { + switch file := srcObj.(type) { + case *File: + if album, ok := dstDir.(*Album); ok { + //rootfile -> album + return d.AddAlbumFile(ctx, album, file) + } + case *AlbumFile: + switch album := dstDir.(type) { + case *Root: + //albumfile -> root + return d.CopyAlbumFile(ctx, file) + case *Album: + // albumfile -> root -> album + rootfile, err := d.CopyAlbumFile(ctx, file) + if err != nil { + return nil, err + } + return d.AddAlbumFile(ctx, album, rootfile) + } + } + return nil, errs.NotSupport +} + +func (d *BaiduPhoto) Move(ctx context.Context, srcObj, dstDir model.Obj) (model.Obj, error) { + if file, ok := srcObj.(*AlbumFile); ok { + switch dstDir.(type) { + case *Album, *Root: // albumfile -> root -> album or albumfile -> root + newObj, err := d.Copy(ctx, srcObj, dstDir) + if err != nil { + return nil, err + } + // delete the original album file + _ = d.DeleteAlbumFile(ctx, file) + return newObj, nil + } + } + return nil, errs.NotSupport +} + +func (d *BaiduPhoto) Rename(ctx context.Context, srcObj model.Obj, newName string) (model.Obj, error) { + // only albums can be renamed + if album, ok := srcObj.(*Album); ok { + return d.SetAlbumName(ctx, album, newName) + } + return nil, errs.NotSupport +} + +func (d *BaiduPhoto) Remove(ctx context.Context, obj model.Obj) error { + switch obj := obj.(type) { + case *File: + return d.DeleteFile(ctx, obj) + case *AlbumFile: + return d.DeleteAlbumFile(ctx, obj) + case *Album: + return d.DeleteAlbum(ctx, obj) + } + return errs.NotSupport +} + +func (d *BaiduPhoto) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) (model.Obj, error) { + // zero-size files are not supported + if stream.GetSize() == 0 { + return nil, fmt.Errorf("file size cannot be zero") + } + + // TODO: + // no rapid-upload (instant transfer) method has been found yet + + // the full file md5 is needed, so the source must support io.Seek + tempFile, err := stream.CacheFullInTempFile() + if err != nil { + return nil, err + } + + const DEFAULT int64 = 1 << 22 + const SliceSize int64 = 1 << 18 + + // compute the values needed below + streamSize := stream.GetSize() + count := int(math.Ceil(float64(streamSize) / float64(DEFAULT))) + lastBlockSize := streamSize % DEFAULT + if lastBlockSize == 0 { + lastBlockSize = DEFAULT + } + + // step 1: calculate the MD5s + sliceMD5List := make([]string, 0, count) + byteSize := int64(DEFAULT) + fileMd5H := md5.New() + sliceMd5H := md5.New() + sliceMd5H2 := md5.New() + slicemd5H2Write := utils.LimitWriter(sliceMd5H2, SliceSize) + for i := 1; i <= count; i++ { + if utils.IsCanceled(ctx) { + return nil, ctx.Err() + } + if i == count { +
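+ // The last slice carries the remainder; e.g. a 9 MiB stream with the 4 MiB (1<<22) DEFAULT slice size gives count == 3 and lastBlockSize == 1 MiB. Worked sketch of the arithmetic above (no extra logic): + // streamSize := int64(9 << 20) + // count := int(math.Ceil(float64(streamSize) / float64(DEFAULT))) // 3 + // lastBlockSize := streamSize % DEFAULT // 1 MiB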
byteSize = lastBlockSize + } + _, err := utils.CopyWithBufferN(io.MultiWriter(fileMd5H, sliceMd5H, slicemd5H2Write), tempFile, byteSize) + if err != nil && err != io.EOF { + return nil, err + } + sliceMD5List = append(sliceMD5List, hex.EncodeToString(sliceMd5H.Sum(nil))) + sliceMd5H.Reset() + } + contentMd5 := hex.EncodeToString(fileMd5H.Sum(nil)) + sliceMd5 := hex.EncodeToString(sliceMd5H2.Sum(nil)) + blockListStr, _ := utils.Json.MarshalToString(sliceMD5List) + + // step 2: precreate (pre-upload) + params := map[string]string{ + "autoinit": "1", + "isdir": "0", + "rtype": "1", + "ctype": "11", + "path": fmt.Sprintf("/%s", stream.GetName()), + "size": fmt.Sprint(stream.GetSize()), + "slice-md5": sliceMd5, + "content-md5": contentMd5, + "block_list": blockListStr, + } + + // try to resume from previously saved progress + precreateResp, ok := base.GetUploadProgress[*PrecreateResp](d, strconv.FormatInt(d.Uk, 10), contentMd5) + if !ok { + _, err = d.Post(FILE_API_URL_V1+"/precreate", func(r *resty.Request) { + r.SetContext(ctx) + r.SetFormData(params) + }, &precreateResp) + if err != nil { + return nil, err + } + } + + switch precreateResp.ReturnType { + case 1: // step 3: upload the file slices + threadG, upCtx := errgroup.NewGroupWithContext(ctx, d.uploadThread, + retry.Attempts(3), + retry.Delay(time.Second), + retry.DelayType(retry.BackOffDelay)) + for i, partseq := range precreateResp.BlockList { + if utils.IsCanceled(upCtx) { + break + } + + i, partseq, offset, byteSize := i, partseq, int64(partseq)*DEFAULT, DEFAULT + if partseq+1 == count { + byteSize = lastBlockSize + } + + threadG.Go(func(ctx context.Context) error { + uploadParams := map[string]string{ + "method": "upload", + "path": params["path"], + "partseq": fmt.Sprint(partseq), + "uploadid": precreateResp.UploadID, + } + + _, err = d.Post("https://c3.pcs.baidu.com/rest/2.0/pcs/superfile2", func(r *resty.Request) { + r.SetContext(ctx) + r.SetQueryParams(uploadParams) + r.SetFileReader("file", stream.GetName(), io.NewSectionReader(tempFile, offset, byteSize)) + }, nil) + if err != nil { + return err + } + up(float64(threadG.Success()) * 100 / float64(len(precreateResp.BlockList))) + precreateResp.BlockList[i] = -1 + return nil + }) + } + if err = threadG.Wait(); err != nil { + if errors.Is(err, context.Canceled) { + precreateResp.BlockList = utils.SliceFilter(precreateResp.BlockList, func(s int) bool { return s >= 0 }) + base.SaveUploadProgress(d, precreateResp, strconv.FormatInt(d.Uk, 10), contentMd5) + } + return nil, err + } + fallthrough + case 2: // step 4: create the file + params["uploadid"] = precreateResp.UploadID + _, err = d.Post(FILE_API_URL_V1+"/create", func(r *resty.Request) { + r.SetContext(ctx) + r.SetFormData(params) + }, &precreateResp) + if err != nil { + return nil, err + } + fallthrough + case 3: // step 5: add the file to the album + rootfile := precreateResp.Data.toFile() + if album, ok := dstDir.(*Album); ok { + return d.AddAlbumFile(ctx, album, rootfile) + } + return rootfile, nil + } + return nil, errs.NotSupport +} + +var _ driver.Driver = (*BaiduPhoto)(nil) +var _ driver.GetRooter = (*BaiduPhoto)(nil) +var _ driver.MkdirResult = (*BaiduPhoto)(nil) +var _ driver.CopyResult = (*BaiduPhoto)(nil) +var _ driver.MoveResult = (*BaiduPhoto)(nil) +var _ driver.Remove = (*BaiduPhoto)(nil) +var _ driver.PutResult = (*BaiduPhoto)(nil) +var _ driver.RenameResult = (*BaiduPhoto)(nil) diff --git a/drivers/baidu_photo/help.go b/drivers/baidu_photo/help.go new file mode 100644 index 0000000000000000000000000000000000000000..40588ee99ee0dd01ba0f33580c55096bf5ad729d --- /dev/null +++ b/drivers/baidu_photo/help.go @@ -0,0 +1,78 @@ +package
baiduphoto + +import ( + "fmt" + "math" + "math/rand" + "strings" + "time" + + "github.com/alist-org/alist/v3/pkg/utils" +) + +// Tid生成 +func getTid() string { + return fmt.Sprintf("3%d%.0f", time.Now().Unix(), math.Floor(9000000*rand.Float64()+1000000)) +} + +func toTime(t int64) *time.Time { + tm := time.Unix(t, 0) + return &tm +} + +func fsidsFormatNotUk(ids ...int64) string { + buf := utils.MustSliceConvert(ids, func(id int64) string { + return fmt.Sprintf(`{"fsid":%d}`, id) + }) + return fmt.Sprintf("[%s]", strings.Join(buf, ",")) +} + +func getFileName(path string) string { + return path[strings.LastIndex(path, "/")+1:] +} + +func MustString(str string, err error) string { + return str +} + +/* +* 处理文件变化 +* 最大程度利用重复数据 +**/ +func copyFile(file *AlbumFile, cf *CopyFile) *File { + return &File{ + Fsid: cf.Fsid, + Path: cf.Path, + Ctime: cf.Ctime, + Mtime: cf.Ctime, + Size: file.Size, + Thumburl: file.Thumburl, + } +} + +func moveFileToAlbumFile(file *File, album *Album, uk int64) *AlbumFile { + return &AlbumFile{ + File: *file, + AlbumID: album.AlbumID, + Tid: album.Tid, + Uk: uk, + } +} + +func renameAlbum(album *Album, newName string) *Album { + return &Album{ + AlbumID: album.AlbumID, + Tid: album.Tid, + JoinTime: album.JoinTime, + CreationTime: album.CreationTime, + Title: newName, + Mtime: time.Now().Unix(), + } +} + +func BoolToIntStr(b bool) string { + if b { + return "1" + } + return "0" +} diff --git a/drivers/baidu_photo/meta.go b/drivers/baidu_photo/meta.go new file mode 100644 index 0000000000000000000000000000000000000000..3bc2f6227c549e71482718067321125b51f9f752 --- /dev/null +++ b/drivers/baidu_photo/meta.go @@ -0,0 +1,29 @@ +package baiduphoto + +import ( + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/op" +) + +type Addition struct { + // RefreshToken string `json:"refresh_token" required:"true"` + Cookie string `json:"cookie" required:"true"` + ShowType string `json:"show_type" type:"select" options:"root,root_only_album,root_only_file" default:"root"` + AlbumID string `json:"album_id"` + //AlbumPassword string `json:"album_password"` + DeleteOrigin bool `json:"delete_origin"` + // ClientID string `json:"client_id" required:"true" default:"iYCeC9g08h5vuP9UqvPHKKSVrKFXGa1v"` + // ClientSecret string `json:"client_secret" required:"true" default:"jXiFMOPVPCWlO2M5CwWQzffpNPaGTRBG"` + UploadThread string `json:"upload_thread" default:"3" help:"1<=thread<=32"` +} + +var config = driver.Config{ + Name: "BaiduPhoto", + LocalSort: true, +} + +func init() { + op.RegisterDriver(func() driver.Driver { + return &BaiduPhoto{} + }) +} diff --git a/drivers/baidu_photo/types.go b/drivers/baidu_photo/types.go new file mode 100644 index 0000000000000000000000000000000000000000..0e5cbb2cdd5e3165242d89fd5c3a70fe82185461 --- /dev/null +++ b/drivers/baidu_photo/types.go @@ -0,0 +1,196 @@ +package baiduphoto + +import ( + "fmt" + "time" + + "github.com/alist-org/alist/v3/pkg/utils" + + "github.com/alist-org/alist/v3/internal/model" +) + +type TokenErrResp struct { + ErrorDescription string `json:"error_description"` + ErrorMsg string `json:"error"` +} + +func (e *TokenErrResp) Error() string { + return fmt.Sprint(e.ErrorMsg, " : ", e.ErrorDescription) +} + +type Erron struct { + Errno int `json:"errno"` + RequestID int `json:"request_id"` +} + +// 用户信息 +type UInfo struct { + // uk + YouaID string `json:"youa_id"` +} + +type Page struct { + HasMore int `json:"has_more"` + Cursor string `json:"cursor"` +} + +func (p Page) HasNextPage() bool { + 
return p.HasMore == 1 +} + +type Root = model.Object + +type ( + FileListResp struct { + Page + List []File `json:"list"` + } + + File struct { + Fsid int64 `json:"fsid"` // file ID + Path string `json:"path"` // file path + Size int64 `json:"size"` + Ctime int64 `json:"ctime"` // creation time (s) + Mtime int64 `json:"mtime"` // modification time (s) + Thumburl []string `json:"thumburl"` + Md5 string `json:"md5"` + } +) + +func (c *File) GetSize() int64 { return c.Size } +func (c *File) GetName() string { return getFileName(c.Path) } +func (c *File) CreateTime() time.Time { return time.Unix(c.Ctime, 0) } +func (c *File) ModTime() time.Time { return time.Unix(c.Mtime, 0) } +func (c *File) IsDir() bool { return false } +func (c *File) GetID() string { return "" } +func (c *File) GetPath() string { return "" } +func (c *File) Thumb() string { + if len(c.Thumburl) > 0 { + return c.Thumburl[0] + } + return "" +} + +func (c *File) GetHash() utils.HashInfo { + return utils.NewHashInfo(utils.MD5, DecryptMd5(c.Md5)) +} + +/* albums */ +type ( + AlbumListResp struct { + Page + List []Album `json:"list"` + Reset int64 `json:"reset"` + TotalCount int64 `json:"total_count"` + } + + Album struct { + AlbumID string `json:"album_id"` + Tid int64 `json:"tid"` + Title string `json:"title"` + JoinTime int64 `json:"join_time"` + CreationTime int64 `json:"create_time"` + Mtime int64 `json:"mtime"` + + parseTime *time.Time + } + + AlbumFileListResp struct { + Page + List []AlbumFile `json:"list"` + Reset int64 `json:"reset"` + TotalCount int64 `json:"total_count"` + } + + AlbumFile struct { + File + AlbumID string `json:"album_id"` + Tid int64 `json:"tid"` + Uk int64 `json:"uk"` + } +) + +func (a *Album) GetHash() utils.HashInfo { + return utils.HashInfo{} +} + +func (a *Album) GetSize() int64 { return 0 } +func (a *Album) GetName() string { return a.Title } +func (a *Album) CreateTime() time.Time { return time.Unix(a.CreationTime, 0) } +func (a *Album) ModTime() time.Time { return time.Unix(a.Mtime, 0) } +func (a *Album) IsDir() bool { return true } +func (a *Album) GetID() string { return "" } +func (a *Album) GetPath() string { return "" } + +type ( + CopyFileResp struct { + List []CopyFile `json:"list"` + } + CopyFile struct { + FromFsid int64 `json:"from_fsid"` // source ID + Ctime int64 `json:"ctime"` + Fsid int64 `json:"fsid"` // target ID + Path string `json:"path"` + ShootTime int `json:"shoot_time"` + } +) + +/* upload */ +type ( + UploadFile struct { + FsID int64 `json:"fs_id"` + Size int64 `json:"size"` + Md5 string `json:"md5"` + ServerFilename string `json:"server_filename"` + Path string `json:"path"` + Ctime int64 `json:"ctime"` + Mtime int64 `json:"mtime"` + Isdir int `json:"isdir"` + Category int `json:"category"` + ServerMd5 string `json:"server_md5"` + ShootTime int `json:"shoot_time"` + } + + CreateFileResp struct { + Data UploadFile `json:"data"` + } + + PrecreateResp struct { + ReturnType int `json:"return_type"` // 2 if the file exists, 1 if it does not, 3 if already saved + // returned when the file exists + CreateFileResp + + // returned when the file does not exist + Path string `json:"path"` + UploadID string `json:"uploadid"` + BlockList []int `json:"block_list"` + } +) + +func (f *UploadFile) toFile() *File { + return &File{ + Fsid: f.FsID, + Path: f.Path, + Size: f.Size, + Ctime: f.Ctime, + Mtime: f.Mtime, + Thumburl: nil, + } +} + +/* shared albums */ +type InviteResp struct { + Pdata struct { + // invite code + InviteCode string `json:"invite_code"` + // expiry time + ExpireTime int `json:"expire_time"` + ShareID string `json:"share_id"` + } `json:"pdata"` +} + +/* joining albums */ +type JoinOrCreateAlbumResp struct { + AlbumID string `json:"album_id"` +
AlreadyExists int `json:"already_exists"` +} diff --git a/drivers/baidu_photo/utils.go b/drivers/baidu_photo/utils.go new file mode 100644 index 0000000000000000000000000000000000000000..0b960593bce31a0bfe28cb5edc7082c9c6332b5f --- /dev/null +++ b/drivers/baidu_photo/utils.go @@ -0,0 +1,514 @@ +package baiduphoto + +import ( + "context" + "encoding/hex" + "fmt" + "net/http" + "strconv" + "strings" + "unicode" + + "github.com/alist-org/alist/v3/drivers/base" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/go-resty/resty/v2" +) + +const ( + API_URL = "https://photo.baidu.com/youai" + USER_API_URL = API_URL + "/user/v1" + ALBUM_API_URL = API_URL + "/album/v1" + FILE_API_URL_V1 = API_URL + "/file/v1" + FILE_API_URL_V2 = API_URL + "/file/v2" +) + +func (d *BaiduPhoto) Request(client *resty.Client, furl string, method string, callback base.ReqCallback, resp interface{}) (*resty.Response, error) { + req := client.R(). + // SetQueryParam("access_token", d.AccessToken) + SetHeader("Cookie", d.Cookie) + if callback != nil { + callback(req) + } + if resp != nil { + req.SetResult(resp) + } + res, err := req.Execute(method, furl) + if err != nil { + return nil, err + } + + erron := utils.Json.Get(res.Body(), "errno").ToInt() + switch erron { + case 0: + break + case 50805: + return nil, fmt.Errorf("you have joined album") + case 50820: + return nil, fmt.Errorf("no shared albums found") + case 50100: + return nil, fmt.Errorf("illegal title, only supports 50 characters") + // case -6: + // if err = d.refreshToken(); err != nil { + // return nil, err + // } + default: + return nil, fmt.Errorf("errno: %d, refer to https://photo.baidu.com/union/doc", erron) + } + return res, nil +} + +//func (d *BaiduPhoto) Request(furl string, method string, callback base.ReqCallback, resp interface{}) ([]byte, error) { +// res, err := d.request(furl, method, callback, resp) +// if err != nil { +// return nil, err +// } +// return res.Body(), nil +//} + +// func (d *BaiduPhoto) refreshToken() error { +// u := "https://openapi.baidu.com/oauth/2.0/token" +// var resp base.TokenResp +// var e TokenErrResp +// _, err := base.RestyClient.R().SetResult(&resp).SetError(&e).SetQueryParams(map[string]string{ +// "grant_type": "refresh_token", +// "refresh_token": d.RefreshToken, +// "client_id": d.ClientID, +// "client_secret": d.ClientSecret, +// }).Get(u) +// if err != nil { +// return err +// } +// if e.ErrorMsg != "" { +// return &e +// } +// if resp.RefreshToken == "" { +// return errs.EmptyToken +// } +// d.AccessToken, d.RefreshToken = resp.AccessToken, resp.RefreshToken +// op.MustSaveDriverStorage(d) +// return nil +// } + +func (d *BaiduPhoto) Get(furl string, callback base.ReqCallback, resp interface{}) (*resty.Response, error) { + return d.Request(base.RestyClient, furl, http.MethodGet, callback, resp) +} + +func (d *BaiduPhoto) Post(furl string, callback base.ReqCallback, resp interface{}) (*resty.Response, error) { + return d.Request(base.RestyClient, furl, http.MethodPost, callback, resp) +} + +// 获取所有文件 +func (d *BaiduPhoto) GetAllFile(ctx context.Context) (files []File, err error) { + var cursor string + for { + var resp FileListResp + _, err = d.Get(FILE_API_URL_V1+"/list", func(r *resty.Request) { + r.SetContext(ctx) + r.SetQueryParams(map[string]string{ + "need_thumbnail": "1", + "need_filter_hidden": "0", + "cursor": cursor, + }) + }, &resp) + if err != nil { + return + } + + files = append(files, resp.List...) 
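+		// Unlike AlbumListResp, FileListResp carries no total_count, so the
+		// result slice cannot be preallocated and simply grows per page.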
+ if !resp.HasNextPage() { + return + } + cursor = resp.Cursor + } +} + +// 删除根文件 +func (d *BaiduPhoto) DeleteFile(ctx context.Context, file *File) error { + _, err := d.Get(FILE_API_URL_V1+"/delete", func(req *resty.Request) { + req.SetContext(ctx) + req.SetQueryParams(map[string]string{ + "fsid_list": fmt.Sprintf("[%d]", file.Fsid), + }) + }, nil) + return err +} + +// 获取所有相册 +func (d *BaiduPhoto) GetAllAlbum(ctx context.Context) (albums []Album, err error) { + var cursor string + for { + var resp AlbumListResp + _, err = d.Get(ALBUM_API_URL+"/list", func(r *resty.Request) { + r.SetContext(ctx) + r.SetQueryParams(map[string]string{ + "need_amount": "1", + "limit": "100", + "cursor": cursor, + }) + }, &resp) + if err != nil { + return + } + if albums == nil { + albums = make([]Album, 0, resp.TotalCount) + } + + cursor = resp.Cursor + albums = append(albums, resp.List...) + + if !resp.HasNextPage() { + return + } + } +} + +// 获取相册中所有文件 +func (d *BaiduPhoto) GetAllAlbumFile(ctx context.Context, album *Album, passwd string) (files []AlbumFile, err error) { + var cursor string + for { + var resp AlbumFileListResp + _, err = d.Get(ALBUM_API_URL+"/listfile", func(r *resty.Request) { + r.SetContext(ctx) + r.SetQueryParams(map[string]string{ + "album_id": album.AlbumID, + "need_amount": "1", + "limit": "1000", + "passwd": passwd, + "cursor": cursor, + }) + }, &resp) + if err != nil { + return + } + if files == nil { + files = make([]AlbumFile, 0, resp.TotalCount) + } + + cursor = resp.Cursor + files = append(files, resp.List...) + + if !resp.HasNextPage() { + return + } + } +} + +// 创建相册 +func (d *BaiduPhoto) CreateAlbum(ctx context.Context, name string) (*Album, error) { + var resp JoinOrCreateAlbumResp + _, err := d.Post(ALBUM_API_URL+"/create", func(r *resty.Request) { + r.SetContext(ctx).SetResult(&resp) + r.SetQueryParams(map[string]string{ + "title": name, + "tid": getTid(), + "source": "0", + }) + }, nil) + if err != nil { + return nil, err + } + return d.GetAlbumDetail(ctx, resp.AlbumID) +} + +// 相册改名 +func (d *BaiduPhoto) SetAlbumName(ctx context.Context, album *Album, name string) (*Album, error) { + _, err := d.Post(ALBUM_API_URL+"/settitle", func(r *resty.Request) { + r.SetContext(ctx) + r.SetFormData(map[string]string{ + "title": name, + "album_id": album.AlbumID, + "tid": fmt.Sprint(album.Tid), + }) + }, nil) + if err != nil { + return nil, err + } + return renameAlbum(album, name), nil +} + +// 删除相册 +func (d *BaiduPhoto) DeleteAlbum(ctx context.Context, album *Album) error { + _, err := d.Post(ALBUM_API_URL+"/delete", func(r *resty.Request) { + r.SetContext(ctx) + r.SetFormData(map[string]string{ + "album_id": album.AlbumID, + "tid": fmt.Sprint(album.Tid), + "delete_origin_image": BoolToIntStr(d.DeleteOrigin), // 是否删除原图 0 不删除 1 删除 + }) + }, nil) + return err +} + +// 删除相册文件 +func (d *BaiduPhoto) DeleteAlbumFile(ctx context.Context, file *AlbumFile) error { + _, err := d.Post(ALBUM_API_URL+"/delfile", func(r *resty.Request) { + r.SetContext(ctx) + r.SetFormData(map[string]string{ + "album_id": fmt.Sprint(file.AlbumID), + "tid": fmt.Sprint(file.Tid), + "list": fmt.Sprintf(`[{"fsid":%d,"uk":%d}]`, file.Fsid, file.Uk), + "del_origin": BoolToIntStr(d.DeleteOrigin), // 是否删除原图 0 不删除 1 删除 + }) + }, nil) + return err +} + +// 增加相册文件 +func (d *BaiduPhoto) AddAlbumFile(ctx context.Context, album *Album, file *File) (*AlbumFile, error) { + _, err := d.Get(ALBUM_API_URL+"/addfile", func(r *resty.Request) { + r.SetContext(ctx) + r.SetQueryParams(map[string]string{ + "album_id": 
fmt.Sprint(album.AlbumID), + "tid": fmt.Sprint(album.Tid), + "list": fsidsFormatNotUk(file.Fsid), + }) + }, nil) + if err != nil { + return nil, err + } + return moveFileToAlbumFile(file, album, d.Uk), nil +} + +// 保存相册文件为根文件 +func (d *BaiduPhoto) CopyAlbumFile(ctx context.Context, file *AlbumFile) (*File, error) { + var resp CopyFileResp + _, err := d.Post(ALBUM_API_URL+"/copyfile", func(r *resty.Request) { + r.SetContext(ctx) + r.SetFormData(map[string]string{ + "album_id": file.AlbumID, + "tid": fmt.Sprint(file.Tid), + "uk": fmt.Sprint(file.Uk), + "list": fsidsFormatNotUk(file.Fsid), + }) + r.SetResult(&resp) + }, nil) + if err != nil { + return nil, err + } + return copyFile(file, &resp.List[0]), nil +} + +// 加入相册 +func (d *BaiduPhoto) JoinAlbum(ctx context.Context, code string) (*Album, error) { + var resp InviteResp + _, err := d.Get(ALBUM_API_URL+"/querypcode", func(req *resty.Request) { + req.SetContext(ctx) + req.SetQueryParams(map[string]string{ + "pcode": code, + "web": "1", + }) + }, &resp) + if err != nil { + return nil, err + } + var resp2 JoinOrCreateAlbumResp + _, err = d.Get(ALBUM_API_URL+"/join", func(req *resty.Request) { + req.SetContext(ctx) + req.SetQueryParams(map[string]string{ + "invite_code": resp.Pdata.InviteCode, + }) + }, &resp2) + if err != nil { + return nil, err + } + return d.GetAlbumDetail(ctx, resp2.AlbumID) +} + +// 获取相册详细信息 +func (d *BaiduPhoto) GetAlbumDetail(ctx context.Context, albumID string) (*Album, error) { + var album Album + _, err := d.Get(ALBUM_API_URL+"/detail", func(req *resty.Request) { + req.SetContext(ctx).SetResult(&album) + req.SetQueryParams(map[string]string{ + "album_id": albumID, + }) + }, &album) + if err != nil { + return nil, err + } + return &album, nil +} + +func (d *BaiduPhoto) linkAlbum(ctx context.Context, file *AlbumFile, args model.LinkArgs) (*model.Link, error) { + headers := map[string]string{ + "User-Agent": base.UserAgent, + } + if args.Header.Get("User-Agent") != "" { + headers["User-Agent"] = args.Header.Get("User-Agent") + } + if !utils.IsLocalIPAddr(args.IP) { + headers["X-Forwarded-For"] = args.IP + } + + resp, err := d.Request(base.NoRedirectClient, ALBUM_API_URL+"/download", http.MethodHead, func(r *resty.Request) { + r.SetContext(ctx) + r.SetHeaders(headers) + r.SetQueryParams(map[string]string{ + "fsid": fmt.Sprint(file.Fsid), + "album_id": file.AlbumID, + "tid": fmt.Sprint(file.Tid), + "uk": fmt.Sprint(file.Uk), + }) + }, nil) + + if err != nil { + return nil, err + } + + if resp.StatusCode() != 302 { + return nil, fmt.Errorf("not found 302 redirect") + } + + location := resp.Header().Get("Location") + + link := &model.Link{ + URL: location, + Header: http.Header{ + "User-Agent": []string{headers["User-Agent"]}, + "Referer": []string{"https://photo.baidu.com/"}, + }, + } + return link, nil +} + +func (d *BaiduPhoto) linkFile(ctx context.Context, file *File, args model.LinkArgs) (*model.Link, error) { + headers := map[string]string{ + "User-Agent": base.UserAgent, + } + if args.Header.Get("User-Agent") != "" { + headers["User-Agent"] = args.Header.Get("User-Agent") + } + if !utils.IsLocalIPAddr(args.IP) { + headers["X-Forwarded-For"] = args.IP + } + + var downloadUrl struct { + Dlink string `json:"dlink"` + } + _, err := d.Get(FILE_API_URL_V2+"/download", func(r *resty.Request) { + r.SetContext(ctx) + r.SetHeaders(headers) + r.SetQueryParams(map[string]string{ + "fsid": fmt.Sprint(file.Fsid), + }) + }, &downloadUrl) + + // resp, err := d.Request(base.NoRedirectClient, FILE_API_URL_V1+"/download", 
http.MethodHead, func(r *resty.Request) { + // r.SetContext(ctx) + // r.SetHeaders(headers) + // r.SetQueryParams(map[string]string{ + // "fsid": fmt.Sprint(file.Fsid), + // }) + // }, nil) + + if err != nil { + return nil, err + } + + // if resp.StatusCode() != 302 { + // return nil, fmt.Errorf("not found 302 redirect") + // } + + // location := resp.Header().Get("Location") + link := &model.Link{ + URL: downloadUrl.Dlink, + Header: http.Header{ + "User-Agent": []string{headers["User-Agent"]}, + "Referer": []string{"https://photo.baidu.com/"}, + }, + } + return link, nil +} + +/*func (d *BaiduPhoto) linkStreamAlbum(ctx context.Context, file *AlbumFile) (*model.Link, error) { + return &model.Link{ + Header: http.Header{}, + Writer: func(w io.Writer) error { + res, err := d.Get(ALBUM_API_URL+"/streaming", func(r *resty.Request) { + r.SetContext(ctx) + r.SetQueryParams(map[string]string{ + "fsid": fmt.Sprint(file.Fsid), + "album_id": file.AlbumID, + "tid": fmt.Sprint(file.Tid), + "uk": fmt.Sprint(file.Uk), + }).SetDoNotParseResponse(true) + }, nil) + if err != nil { + return err + } + defer res.RawBody().Close() + _, err = io.Copy(w, res.RawBody()) + return err + }, + }, nil +}*/ + +/*func (d *BaiduPhoto) linkStream(ctx context.Context, file *File) (*model.Link, error) { + return &model.Link{ + Header: http.Header{}, + Writer: func(w io.Writer) error { + res, err := d.Get(FILE_API_URL_V1+"/streaming", func(r *resty.Request) { + r.SetContext(ctx) + r.SetQueryParams(map[string]string{ + "fsid": fmt.Sprint(file.Fsid), + }).SetDoNotParseResponse(true) + }, nil) + if err != nil { + return err + } + defer res.RawBody().Close() + _, err = io.Copy(w, res.RawBody()) + return err + }, + }, nil +}*/ + +// 获取uk +func (d *BaiduPhoto) uInfo() (*UInfo, error) { + var info UInfo + _, err := d.Get(USER_API_URL+"/getuinfo", func(req *resty.Request) { + + }, &info) + if err != nil { + return nil, err + } + return &info, nil +} + +func DecryptMd5(encryptMd5 string) string { + if _, err := hex.DecodeString(encryptMd5); err == nil { + return encryptMd5 + } + + var out strings.Builder + out.Grow(len(encryptMd5)) + for i, n := 0, int64(0); i < len(encryptMd5); i++ { + if i == 9 { + n = int64(unicode.ToLower(rune(encryptMd5[i])) - 'g') + } else { + n, _ = strconv.ParseInt(encryptMd5[i:i+1], 16, 64) + } + out.WriteString(strconv.FormatInt(n^int64(15&i), 16)) + } + + encryptMd5 = out.String() + return encryptMd5[8:16] + encryptMd5[:8] + encryptMd5[24:32] + encryptMd5[16:24] +} + +func EncryptMd5(originalMd5 string) string { + reversed := originalMd5[8:16] + originalMd5[:8] + originalMd5[24:32] + originalMd5[16:24] + + var out strings.Builder + out.Grow(len(reversed)) + for i, n := 0, int64(0); i < len(reversed); i++ { + n, _ = strconv.ParseInt(reversed[i:i+1], 16, 64) + n ^= int64(15 & i) + if i == 9 { + out.WriteRune(rune(n) + 'g') + } else { + out.WriteString(strconv.FormatInt(n, 16)) + } + } + return out.String() +} diff --git a/drivers/baidu_share/driver.go b/drivers/baidu_share/driver.go new file mode 100644 index 0000000000000000000000000000000000000000..a77303aa80640385dbcdae65b730537048c2b563 --- /dev/null +++ b/drivers/baidu_share/driver.go @@ -0,0 +1,251 @@ +package baidu_share + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "path" + "time" + + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/errs" + "github.com/alist-org/alist/v3/internal/model" + "github.com/go-resty/resty/v2" +) + +type BaiduShare struct { + model.Storage + 
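+	// Drivers in this repo embed model.Storage (the persisted storage
+	// record) plus a driver-specific Addition holding the user-facing
+	// options; the other drivers in this diff follow the same shape.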
Addition + client *resty.Client + info struct { + Root string + Seckey string + Shareid string + Uk string + } +} + +func (d *BaiduShare) Config() driver.Config { + return config +} + +func (d *BaiduShare) GetAddition() driver.Additional { + return &d.Addition +} + +func (d *BaiduShare) Init(ctx context.Context) error { + // TODO login / refresh token + //op.MustSaveDriverStorage(d) + d.client = resty.New(). + SetBaseURL("https://pan.baidu.com"). + SetHeader("User-Agent", "netdisk"). + SetCookie(&http.Cookie{Name: "BDUSS", Value: d.BDUSS}). + SetCookie(&http.Cookie{Name: "ndut_fmt"}) + respJson := struct { + Errno int64 `json:"errno"` + Data struct { + List [1]struct { + Path string `json:"path"` + } `json:"list"` + Uk json.Number `json:"uk"` + Shareid json.Number `json:"shareid"` + Seckey string `json:"seckey"` + } `json:"data"` + }{} + resp, err := d.client.R(). + SetBody(url.Values{ + "pwd": {d.Pwd}, + "root": {"1"}, + "shorturl": {d.Surl}, + }.Encode()). + SetResult(&respJson). + Post("share/wxlist?channel=weixin&version=2.2.2&clienttype=25&web=1") + if err == nil { + if resp.IsSuccess() && respJson.Errno == 0 { + d.info.Root = path.Dir(respJson.Data.List[0].Path) + d.info.Seckey = respJson.Data.Seckey + d.info.Shareid = respJson.Data.Shareid.String() + d.info.Uk = respJson.Data.Uk.String() + } else { + err = fmt.Errorf(" %s; %s; ", resp.Status(), resp.Body()) + } + } + return err +} + +func (d *BaiduShare) Drop(ctx context.Context) error { + return nil +} + +func (d *BaiduShare) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) { + // TODO return the files list, required + reqDir := dir.GetPath() + isRoot := "0" + if reqDir == d.RootFolderPath { + reqDir = path.Join(d.info.Root, reqDir) + } + if reqDir == d.info.Root { + isRoot = "1" + } + objs := []model.Obj{} + var err error + var page uint64 = 1 + more := true + for more && err == nil { + respJson := struct { + Errno int64 `json:"errno"` + Data struct { + More bool `json:"has_more"` + List []struct { + Fsid json.Number `json:"fs_id"` + Isdir json.Number `json:"isdir"` + Path string `json:"path"` + Name string `json:"server_filename"` + Mtime json.Number `json:"server_mtime"` + Size json.Number `json:"size"` + } `json:"list"` + } `json:"data"` + }{} + resp, e := d.client.R(). + SetBody(url.Values{ + "dir": {reqDir}, + "num": {"1000"}, + "order": {"time"}, + "page": {fmt.Sprint(page)}, + "pwd": {d.Pwd}, + "root": {isRoot}, + "shorturl": {d.Surl}, + }.Encode()). + SetResult(&respJson). + Post("share/wxlist?channel=weixin&version=2.2.2&clienttype=25&web=1") + err = e + if err == nil { + if resp.IsSuccess() && respJson.Errno == 0 { + page++ + more = respJson.Data.More + for _, v := range respJson.Data.List { + size, _ := v.Size.Int64() + mtime, _ := v.Mtime.Int64() + objs = append(objs, &model.Object{ + ID: v.Fsid.String(), + Path: v.Path, + Name: v.Name, + Size: size, + Modified: time.Unix(mtime, 0), + IsFolder: v.Isdir.String() == "1", + }) + } + } else { + err = fmt.Errorf(" %s; %s; ", resp.Status(), resp.Body()) + } + } + } + return objs, err +} + +func (d *BaiduShare) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) { + // TODO return link of file, required + link := model.Link{Header: d.client.Header} + sign := "" + stamp := "" + signJson := struct { + Errno int64 `json:"errno"` + Data struct { + Stamp json.Number `json:"timestamp"` + Sign string `json:"sign"` + } `json:"data"` + }{} + resp, err := d.client.R(). + SetQueryParam("surl", d.Surl). 
+ SetResult(&signJson). + Get("share/tplconfig?fields=sign,timestamp&channel=chunlei&web=1&app_id=250528&clienttype=0") + if err == nil { + if resp.IsSuccess() && signJson.Errno == 0 { + stamp = signJson.Data.Stamp.String() + sign = signJson.Data.Sign + } else { + err = fmt.Errorf(" %s; %s; ", resp.Status(), resp.Body()) + } + } + if err == nil { + respJson := struct { + Errno int64 `json:"errno"` + List [1]struct { + Dlink string `json:"dlink"` + } `json:"list"` + }{} + resp, err = d.client.R(). + SetQueryParam("sign", sign). + SetQueryParam("timestamp", stamp). + SetBody(url.Values{ + "encrypt": {"0"}, + "extra": {fmt.Sprintf(`{"sekey":"%s"}`, d.info.Seckey)}, + "fid_list": {fmt.Sprintf("[%s]", file.GetID())}, + "primaryid": {d.info.Shareid}, + "product": {"share"}, + "type": {"nolimit"}, + "uk": {d.info.Uk}, + }.Encode()). + SetResult(&respJson). + Post("api/sharedownload?app_id=250528&channel=chunlei&clienttype=12&web=1") + if err == nil { + if resp.IsSuccess() && respJson.Errno == 0 && respJson.List[0].Dlink != "" { + link.URL = respJson.List[0].Dlink + } else { + err = fmt.Errorf(" %s; %s; ", resp.Status(), resp.Body()) + } + } + if err == nil { + resp, err = d.client.R(). + SetDoNotParseResponse(true). + Get(link.URL) + if err == nil { + defer resp.RawBody().Close() + if resp.IsError() { + byt, _ := io.ReadAll(resp.RawBody()) + err = fmt.Errorf(" %s; %s; ", resp.Status(), byt) + } + } + } + } + return &link, err +} + +func (d *BaiduShare) MakeDir(ctx context.Context, parentDir model.Obj, dirName string) error { + // TODO create folder, optional + return errs.NotSupport +} + +func (d *BaiduShare) Move(ctx context.Context, srcObj, dstDir model.Obj) error { + // TODO move obj, optional + return errs.NotSupport +} + +func (d *BaiduShare) Rename(ctx context.Context, srcObj model.Obj, newName string) error { + // TODO rename obj, optional + return errs.NotSupport +} + +func (d *BaiduShare) Copy(ctx context.Context, srcObj, dstDir model.Obj) error { + // TODO copy obj, optional + return errs.NotSupport +} + +func (d *BaiduShare) Remove(ctx context.Context, obj model.Obj) error { + // TODO remove obj, optional + return errs.NotSupport +} + +func (d *BaiduShare) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error { + // TODO upload file, optional + return errs.NotSupport +} + +//func (d *Template) Other(ctx context.Context, args model.OtherArgs) (interface{}, error) { +// return nil, errs.NotSupport +//} + +var _ driver.Driver = (*BaiduShare)(nil) diff --git a/drivers/baidu_share/meta.go b/drivers/baidu_share/meta.go new file mode 100644 index 0000000000000000000000000000000000000000..271a692db56286804714f87da3c075dda8426574 --- /dev/null +++ b/drivers/baidu_share/meta.go @@ -0,0 +1,37 @@ +package baidu_share + +import ( + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/op" +) + +type Addition struct { + // Usually one of two + driver.RootPath + // driver.RootID + // define other + // Field string `json:"field" type:"select" required:"true" options:"a,b,c" default:"a"` + Surl string `json:"surl"` + Pwd string `json:"pwd"` + BDUSS string `json:"BDUSS"` +} + +var config = driver.Config{ + Name: "BaiduShare", + LocalSort: true, + OnlyLocal: false, + OnlyProxy: false, + NoCache: false, + NoUpload: true, + NeedMs: false, + DefaultRoot: "/", + CheckStatus: false, + Alert: "", + NoOverwriteUpload: false, +} + +func init() { + op.RegisterDriver(func() driver.Driver { + return &BaiduShare{} + }) +} diff --git 
a/drivers/baidu_share/types.go b/drivers/baidu_share/types.go new file mode 100644 index 0000000000000000000000000000000000000000..1f65cb608777c95d78c6e08e61e2fe6ca7b63e4a --- /dev/null +++ b/drivers/baidu_share/types.go @@ -0,0 +1 @@ +package baidu_share diff --git a/drivers/baidu_share/util.go b/drivers/baidu_share/util.go new file mode 100644 index 0000000000000000000000000000000000000000..6bca3f9368540582dc140114176aa223440b8b3f --- /dev/null +++ b/drivers/baidu_share/util.go @@ -0,0 +1,3 @@ +package baidu_share + +// do others that not defined in Driver interface diff --git a/drivers/base/client.go b/drivers/base/client.go new file mode 100644 index 0000000000000000000000000000000000000000..8bf8f421eea779388b167b274ead423fa55526a3 --- /dev/null +++ b/drivers/base/client.go @@ -0,0 +1,50 @@ +package base + +import ( + "crypto/tls" + "net/http" + "time" + + "github.com/alist-org/alist/v3/internal/conf" + "github.com/go-resty/resty/v2" +) + +var ( + NoRedirectClient *resty.Client + RestyClient *resty.Client + HttpClient *http.Client +) +var UserAgent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/87.0.4280.88 Safari/537.36" +var DefaultTimeout = time.Second * 30 + +func InitClient() { + NoRedirectClient = resty.New().SetRedirectPolicy( + resty.RedirectPolicyFunc(func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }), + ).SetTLSClientConfig(&tls.Config{InsecureSkipVerify: conf.Conf.TlsInsecureSkipVerify}) + NoRedirectClient.SetHeader("user-agent", UserAgent) + + RestyClient = NewRestyClient() + HttpClient = NewHttpClient() +} + +func NewRestyClient() *resty.Client { + client := resty.New(). + SetHeader("user-agent", UserAgent). + SetRetryCount(3). + SetRetryResetReaders(true). + SetTimeout(DefaultTimeout). + SetTLSClientConfig(&tls.Config{InsecureSkipVerify: conf.Conf.TlsInsecureSkipVerify}) + return client +} + +func NewHttpClient() *http.Client { + return &http.Client{ + Timeout: time.Hour * 48, + Transport: &http.Transport{ + Proxy: http.ProxyFromEnvironment, + TLSClientConfig: &tls.Config{InsecureSkipVerify: conf.Conf.TlsInsecureSkipVerify}, + }, + } +} diff --git a/drivers/base/types.go b/drivers/base/types.go new file mode 100644 index 0000000000000000000000000000000000000000..e2757f2175f831b20ae9dc95f1a23f7202de2565 --- /dev/null +++ b/drivers/base/types.go @@ -0,0 +1,12 @@ +package base + +import "github.com/go-resty/resty/v2" + +type Json map[string]interface{} + +type TokenResp struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` +} + +type ReqCallback func(req *resty.Request) diff --git a/drivers/base/upload.go b/drivers/base/upload.go new file mode 100644 index 0000000000000000000000000000000000000000..881a256e093f70959ea5ee654f1762d41e281147 --- /dev/null +++ b/drivers/base/upload.go @@ -0,0 +1,31 @@ +package base + +import ( + "fmt" + "strings" + "time" + + "github.com/Xhofe/go-cache" + "github.com/alist-org/alist/v3/internal/driver" +) + +// storage upload progress, for upload recovery +var UploadStateCache = cache.NewMemCache(cache.WithShards[any](32)) + +// Save upload progress for 20 minutes +func SaveUploadProgress(driver driver.Driver, state any, keys ...string) bool { + return UploadStateCache.Set( + fmt.Sprint(driver.Config().Name, "-upload-", strings.Join(keys, "-")), + state, + cache.WithEx[any](time.Minute*20)) +} + +// An upload progress can only be made by one process alone, +// so here you need to get it and then delete it. 
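+// A hedged usage sketch (MyState and contentMD5 are hypothetical, not part
+// of this diff): a driver resuming a chunked upload would first try
+//
+//	if state, ok := GetUploadProgress[*MyState](d, contentMD5); ok {
+//		// resume from state
+//	}
+//
+// and keep calling SaveUploadProgress(d, state, contentMD5) as chunks finish.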
+func GetUploadProgress[T any](driver driver.Driver, keys ...string) (state T, ok bool) { + v, ok := UploadStateCache.GetDel(fmt.Sprint(driver.Config().Name, "-upload-", strings.Join(keys, "-"))) + if ok { + state, ok = v.(T) + } + return +} diff --git a/drivers/base/util.go b/drivers/base/util.go new file mode 100644 index 0000000000000000000000000000000000000000..22f11114481bfe885dffb5ae232a6cd771fc588b --- /dev/null +++ b/drivers/base/util.go @@ -0,0 +1 @@ +package base diff --git a/drivers/chaoxing/driver.go b/drivers/chaoxing/driver.go new file mode 100644 index 0000000000000000000000000000000000000000..360c6e3d01dc6b663ad23ca46554965cac5183e8 --- /dev/null +++ b/drivers/chaoxing/driver.go @@ -0,0 +1,299 @@ +package chaoxing + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "mime/multipart" + "net/http" + "net/url" + "strings" + "time" + + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/errs" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/internal/op" + "github.com/alist-org/alist/v3/pkg/cron" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/go-resty/resty/v2" + "google.golang.org/appengine/log" +) + +type ChaoXing struct { + model.Storage + Addition + cron *cron.Cron + config driver.Config + conf Conf +} + +func (d *ChaoXing) Config() driver.Config { + return d.config +} + +func (d *ChaoXing) GetAddition() driver.Additional { + return &d.Addition +} + +func (d *ChaoXing) refreshCookie() error { + cookie, err := d.Login() + if err != nil { + d.Status = err.Error() + op.MustSaveDriverStorage(d) + return nil + } + d.Addition.Cookie = cookie + op.MustSaveDriverStorage(d) + return nil +} + +func (d *ChaoXing) Init(ctx context.Context) error { + err := d.refreshCookie() + if err != nil { + log.Errorf(ctx, err.Error()) + } + d.cron = cron.NewCron(time.Hour * 12) + d.cron.Do(func() { + err = d.refreshCookie() + if err != nil { + log.Errorf(ctx, err.Error()) + } + }) + return nil +} + +func (d *ChaoXing) Drop(ctx context.Context) error { + if d.cron != nil { + d.cron.Stop() + } + return nil +} + +func (d *ChaoXing) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) { + files, err := d.GetFiles(dir.GetID()) + if err != nil { + return nil, err + } + return utils.SliceConvert(files, func(src File) (model.Obj, error) { + return fileToObj(src), nil + }) +} + +func (d *ChaoXing) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) { + var resp DownResp + ua := d.conf.ua + fileId := strings.Split(file.GetID(), "$")[1] + _, err := d.requestDownload("/screen/note_note/files/status/"+fileId, http.MethodPost, func(req *resty.Request) { + req.SetHeader("User-Agent", ua) + }, &resp) + if err != nil { + return nil, err + } + u := resp.Download + return &model.Link{ + URL: u, + Header: http.Header{ + "Cookie": []string{d.Cookie}, + "Referer": []string{d.conf.referer}, + "User-Agent": []string{ua}, + }, + Concurrency: 2, + PartSize: 10 * utils.MB, + }, nil +} + +func (d *ChaoXing) MakeDir(ctx context.Context, parentDir model.Obj, dirName string) error { + query := map[string]string{ + "bbsid": d.Addition.Bbsid, + "name": dirName, + "pid": parentDir.GetID(), + } + var resp ListFileResp + _, err := d.request("/pc/resource/addResourceFolder", http.MethodGet, func(req *resty.Request) { + req.SetQueryParams(query) + }, &resp) + if err != nil { + return err + } + if resp.Result != 1 { + msg := fmt.Sprintf("error:%s", resp.Msg) + return 
errors.New(msg) + } + return nil +} + +func (d *ChaoXing) Move(ctx context.Context, srcObj, dstDir model.Obj) error { + query := map[string]string{ + "bbsid": d.Addition.Bbsid, + "folderIds": srcObj.GetID(), + "targetId": dstDir.GetID(), + } + if !srcObj.IsDir() { + query = map[string]string{ + "bbsid": d.Addition.Bbsid, + "recIds": strings.Split(srcObj.GetID(), "$")[0], + "targetId": dstDir.GetID(), + } + } + var resp ListFileResp + _, err := d.request("/pc/resource/moveResource", http.MethodGet, func(req *resty.Request) { + req.SetQueryParams(query) + }, &resp) + if err != nil { + return err + } + if !resp.Status { + msg := fmt.Sprintf("error:%s", resp.Msg) + return errors.New(msg) + } + return nil +} + +func (d *ChaoXing) Rename(ctx context.Context, srcObj model.Obj, newName string) error { + query := map[string]string{ + "bbsid": d.Addition.Bbsid, + "folderId": srcObj.GetID(), + "name": newName, + } + path := "/pc/resource/updateResourceFolderName" + if !srcObj.IsDir() { + // path = "/pc/resource/updateResourceFileName" + // query = map[string]string{ + // "bbsid": d.Addition.Bbsid, + // "recIds": strings.Split(srcObj.GetID(), "$")[0], + // "name": newName, + // } + return errors.New("此网盘不支持修改文件名") + } + var resp ListFileResp + _, err := d.request(path, http.MethodGet, func(req *resty.Request) { + req.SetQueryParams(query) + }, &resp) + if err != nil { + return err + } + if resp.Result != 1 { + msg := fmt.Sprintf("error:%s", resp.Msg) + return errors.New(msg) + } + return nil +} + +func (d *ChaoXing) Copy(ctx context.Context, srcObj, dstDir model.Obj) error { + // TODO copy obj, optional + return errs.NotImplement +} + +func (d *ChaoXing) Remove(ctx context.Context, obj model.Obj) error { + query := map[string]string{ + "bbsid": d.Addition.Bbsid, + "folderIds": obj.GetID(), + } + path := "/pc/resource/deleteResourceFolder" + var resp ListFileResp + if !obj.IsDir() { + path = "/pc/resource/deleteResourceFile" + query = map[string]string{ + "bbsid": d.Addition.Bbsid, + "recIds": strings.Split(obj.GetID(), "$")[0], + } + } + _, err := d.request(path, http.MethodGet, func(req *resty.Request) { + req.SetQueryParams(query) + }, &resp) + if err != nil { + return err + } + if resp.Result != 1 { + msg := fmt.Sprintf("error:%s", resp.Msg) + return errors.New(msg) + } + return nil +} + +func (d *ChaoXing) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error { + var resp UploadDataRsp + _, err := d.request("https://noteyd.chaoxing.com/pc/files/getUploadConfig", http.MethodGet, func(req *resty.Request) { + }, &resp) + if err != nil { + return err + } + if resp.Result != 1 { + return errors.New("get upload data error") + } + body := &bytes.Buffer{} + writer := multipart.NewWriter(body) + filePart, err := writer.CreateFormFile("file", stream.GetName()) + if err != nil { + return err + } + _, err = utils.CopyWithBuffer(filePart, stream) + if err != nil { + return err + } + err = writer.WriteField("_token", resp.Msg.Token) + if err != nil { + return err + } + err = writer.WriteField("puid", fmt.Sprintf("%d", resp.Msg.Puid)) + if err != nil { + fmt.Println("Error writing param2 to request body:", err) + return err + } + err = writer.Close() + if err != nil { + return err + } + req, err := http.NewRequest("POST", "https://pan-yz.chaoxing.com/upload", body) + if err != nil { + return err + } + req.Header.Set("Content-Type", writer.FormDataContentType()) + req.Header.Set("Content-Length", fmt.Sprintf("%d", body.Len())) + resps, err := 
http.DefaultClient.Do(req)
+	if err != nil {
+		return err
+	}
+	defer resps.Body.Close()
+	bodys, err := io.ReadAll(resps.Body)
+	if err != nil {
+		return err
+	}
+	var fileRsp UploadFileDataRsp
+	err = json.Unmarshal(bodys, &fileRsp)
+	if err != nil {
+		return err
+	}
+	if fileRsp.Msg != "success" {
+		return errors.New(fileRsp.Msg)
+	}
+	uploadDoneParam := UploadDoneParam{Key: fileRsp.ObjectID, Cataid: "100000019", Param: fileRsp.Data}
+	params, err := json.Marshal(uploadDoneParam)
+	if err != nil {
+		return err
+	}
+	query := map[string]string{
+		"bbsid":  d.Addition.Bbsid,
+		"pid":    dstDir.GetID(),
+		"type":   "yunpan",
+		"params": url.QueryEscape("[" + string(params) + "]"),
+	}
+	var respd ListFileResp
+	_, err = d.request("/pc/resource/addResource", http.MethodGet, func(req *resty.Request) {
+		req.SetQueryParams(query)
+	}, &respd)
+	if err != nil {
+		return err
+	}
+	if respd.Result != 1 {
+		msg := fmt.Sprintf("error:%v", respd.Msg)
+		return errors.New(msg)
+	}
+	return nil
+}
+
+var _ driver.Driver = (*ChaoXing)(nil)
diff --git a/drivers/chaoxing/meta.go b/drivers/chaoxing/meta.go
new file mode 100644
index 0000000000000000000000000000000000000000..c0500629cf365c5ceca2a37794f9ef630c318854
--- /dev/null
+++ b/drivers/chaoxing/meta.go
@@ -0,0 +1,47 @@
+package chaoxing
+
+import (
+	"github.com/alist-org/alist/v3/internal/driver"
+	"github.com/alist-org/alist/v3/internal/op"
+)
+
+// This driver mounts the ChaoXing group drive and needs a proxy to work.
+// After logging in to ChaoXing, open your personal space, go to Groups,
+// create a new group and click into it; the bbsid parameter is in that
+// group's URL. The system limits single files to 2 GB; there is no
+// total-capacity limit.
+type Addition struct {
+	// ChaoXing username and password
+	UserName string `json:"user_name" required:"true"`
+	Password string `json:"password" required:"true"`
+	// Taken from the URL of the group you created
+	Bbsid string `json:"bbsid" required:"true"`
+	driver.RootID
+	// Optional; the program logs in and fetches it automatically
+	Cookie string `json:"cookie"`
+}
+
+type Conf struct {
+	ua         string
+	referer    string
+	api        string
+	DowloadApi string
+}
+
+func init() {
+	op.RegisterDriver(func() driver.Driver {
+		return &ChaoXing{
+			config: driver.Config{
+				Name:              "ChaoXingGroupDrive",
+				OnlyProxy:         true,
+				OnlyLocal:         false,
+				DefaultRoot:       "-1",
+				NoOverwriteUpload: true,
+			},
+			conf: Conf{
+				ua:         "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) quark-cloud-drive/2.5.20 Chrome/100.0.4896.160 Electron/18.3.5.4-b478491100 Safari/537.36 Channel/pckk_other_ch",
+				referer:    "https://chaoxing.com/",
+				api:        "https://groupweb.chaoxing.com",
+				DowloadApi: "https://noteyd.chaoxing.com",
+			},
+		}
+	})
+}
diff --git a/drivers/chaoxing/types.go b/drivers/chaoxing/types.go
new file mode 100644
index 0000000000000000000000000000000000000000..71a59e15be638886c2181f227ea000ebcda30b79
--- /dev/null
+++ b/drivers/chaoxing/types.go
@@ -0,0 +1,276 @@
+package chaoxing
+
+import (
+	"bytes"
+	"fmt"
+	"strconv"
+	"time"
+
+	"github.com/alist-org/alist/v3/internal/model"
+)
+
+type Resp struct {
+	Result int `json:"result"`
+}
+
+type UserAuth struct {
+	GroupAuth struct {
+		AddData                 int    `json:"addData"`
+		AddDataFolder           int    `json:"addDataFolder"`
+		AddLebel                int    `json:"addLebel"`
+		AddManager              int    `json:"addManager"`
+		AddMem                  int    `json:"addMem"`
+		AddTopicFolder          int    `json:"addTopicFolder"`
+		AnonymousAddReply       int    `json:"anonymousAddReply"`
+		AnonymousAddTopic       int    `json:"anonymousAddTopic"`
+		BatchOperation          int    `json:"batchOperation"`
+		DelData                 int    `json:"delData"`
+		DelDataFolder           int    `json:"delDataFolder"`
+		DelMem                  int    `json:"delMem"`
+		DelTopicFolder          int    `json:"delTopicFolder"`
+		Dismiss                 int    `json:"dismiss"`
+		ExamEnc                 string `json:"examEnc"`
+		GroupChat               int    `json:"groupChat"`
+		IsShowCircleChatButton
int `json:"isShowCircleChatButton"` + IsShowCircleCloudButton int `json:"isShowCircleCloudButton"` + IsShowCompanyButton int `json:"isShowCompanyButton"` + Join int `json:"join"` + MemberShowRankSet int `json:"memberShowRankSet"` + ModifyDataFolder int `json:"modifyDataFolder"` + ModifyExpose int `json:"modifyExpose"` + ModifyName int `json:"modifyName"` + ModifyShowPic int `json:"modifyShowPic"` + ModifyTopicFolder int `json:"modifyTopicFolder"` + ModifyVisibleState int `json:"modifyVisibleState"` + OnlyMgrScoreSet int `json:"onlyMgrScoreSet"` + Quit int `json:"quit"` + SendNotice int `json:"sendNotice"` + ShowActivityManage int `json:"showActivityManage"` + ShowActivitySet int `json:"showActivitySet"` + ShowAttentionSet int `json:"showAttentionSet"` + ShowAutoClearStatus int `json:"showAutoClearStatus"` + ShowBarcode int `json:"showBarcode"` + ShowChatRoomSet int `json:"showChatRoomSet"` + ShowCircleActivitySet int `json:"showCircleActivitySet"` + ShowCircleSet int `json:"showCircleSet"` + ShowCmem int `json:"showCmem"` + ShowDataFolder int `json:"showDataFolder"` + ShowDelReason int `json:"showDelReason"` + ShowForward int `json:"showForward"` + ShowGroupChat int `json:"showGroupChat"` + ShowGroupChatSet int `json:"showGroupChatSet"` + ShowGroupSquareSet int `json:"showGroupSquareSet"` + ShowLockAddSet int `json:"showLockAddSet"` + ShowManager int `json:"showManager"` + ShowManagerIdentitySet int `json:"showManagerIdentitySet"` + ShowNeedDelReasonSet int `json:"showNeedDelReasonSet"` + ShowNotice int `json:"showNotice"` + ShowOnlyManagerReplySet int `json:"showOnlyManagerReplySet"` + ShowRank int `json:"showRank"` + ShowRank2 int `json:"showRank2"` + ShowRecycleBin int `json:"showRecycleBin"` + ShowReplyByClass int `json:"showReplyByClass"` + ShowReplyNeedCheck int `json:"showReplyNeedCheck"` + ShowSignbanSet int `json:"showSignbanSet"` + ShowSpeechSet int `json:"showSpeechSet"` + ShowTopicCheck int `json:"showTopicCheck"` + ShowTopicNeedCheck int `json:"showTopicNeedCheck"` + ShowTransferSet int `json:"showTransferSet"` + } `json:"groupAuth"` + OperationAuth struct { + Add int `json:"add"` + AddTopicToFolder int `json:"addTopicToFolder"` + ChoiceSet int `json:"choiceSet"` + DelTopicFromFolder int `json:"delTopicFromFolder"` + Delete int `json:"delete"` + Reply int `json:"reply"` + ScoreSet int `json:"scoreSet"` + TopSet int `json:"topSet"` + Update int `json:"update"` + } `json:"operationAuth"` +} + +// 手机端学习通上传的文件的json内容(content字段)与网页端上传的有所不同 +// 网页端json `"puid": 54321, "size": 12345` +// 手机端json `"puid": "54321". 
"size": "12345"` +type int_str int + +// json 字符串数字和纯数字解析 +func (ios *int_str) UnmarshalJSON(data []byte) error { + intValue, err := strconv.Atoi(string(bytes.Trim(data, "\""))) + if err != nil { + return err + } + *ios = int_str(intValue) + return nil +} + +type File struct { + Cataid int `json:"cataid"` + Cfid int `json:"cfid"` + Content struct { + Cfid int `json:"cfid"` + Pid int `json:"pid"` + FolderName string `json:"folderName"` + ShareType int `json:"shareType"` + Preview string `json:"preview"` + Filetype string `json:"filetype"` + PreviewURL string `json:"previewUrl"` + IsImg bool `json:"isImg"` + ParentPath string `json:"parentPath"` + Icon string `json:"icon"` + Suffix string `json:"suffix"` + Duration int `json:"duration"` + Pantype string `json:"pantype"` + Puid int_str `json:"puid"` + Filepath string `json:"filepath"` + Crc string `json:"crc"` + Isfile bool `json:"isfile"` + Residstr string `json:"residstr"` + ObjectID string `json:"objectId"` + Extinfo string `json:"extinfo"` + Thumbnail string `json:"thumbnail"` + Creator int `json:"creator"` + ResTypeValue int `json:"resTypeValue"` + UploadDateFormat string `json:"uploadDateFormat"` + DisableOpt bool `json:"disableOpt"` + DownPath string `json:"downPath"` + Sort int `json:"sort"` + Topsort int `json:"topsort"` + Restype string `json:"restype"` + Size int_str `json:"size"` + UploadDate int64 `json:"uploadDate"` + FileSize string `json:"fileSize"` + Name string `json:"name"` + FileID string `json:"fileId"` + } `json:"content"` + CreatorID int `json:"creatorId"` + DesID string `json:"des_id"` + ID int `json:"id"` + Inserttime int64 `json:"inserttime"` + Key string `json:"key"` + Norder int `json:"norder"` + OwnerID int `json:"ownerId"` + OwnerType int `json:"ownerType"` + Path string `json:"path"` + Rid int `json:"rid"` + Status int `json:"status"` + Topsign int `json:"topsign"` +} + +type ListFileResp struct { + Msg string `json:"msg"` + Result int `json:"result"` + Status bool `json:"status"` + UserAuth UserAuth `json:"userAuth"` + List []File `json:"list"` +} + +type DownResp struct { + Msg string `json:"msg"` + Duration int `json:"duration"` + Download string `json:"download"` + FileStatus string `json:"fileStatus"` + URL string `json:"url"` + Status bool `json:"status"` +} + +type UploadDataRsp struct { + Result int `json:"result"` + Msg struct { + Puid int `json:"puid"` + Token string `json:"token"` + } `json:"msg"` +} + +type UploadFileDataRsp struct { + Result bool `json:"result"` + Msg string `json:"msg"` + Crc string `json:"crc"` + ObjectID string `json:"objectId"` + Resid int64 `json:"resid"` + Puid int `json:"puid"` + Data struct { + DisableOpt bool `json:"disableOpt"` + Resid int64 `json:"resid"` + Crc string `json:"crc"` + Puid int `json:"puid"` + Isfile bool `json:"isfile"` + Pantype string `json:"pantype"` + Size int `json:"size"` + Name string `json:"name"` + ObjectID string `json:"objectId"` + Restype string `json:"restype"` + UploadDate int64 `json:"uploadDate"` + ModifyDate int64 `json:"modifyDate"` + UploadDateFormat string `json:"uploadDateFormat"` + Residstr string `json:"residstr"` + Suffix string `json:"suffix"` + Preview string `json:"preview"` + Thumbnail string `json:"thumbnail"` + Creator int `json:"creator"` + Duration int `json:"duration"` + IsImg bool `json:"isImg"` + PreviewURL string `json:"previewUrl"` + Filetype string `json:"filetype"` + Filepath string `json:"filepath"` + Sort int `json:"sort"` + Topsort int `json:"topsort"` + ResTypeValue int `json:"resTypeValue"` + Extinfo string 
`json:"extinfo"` + } `json:"data"` +} + +type UploadDoneParam struct { + Cataid string `json:"cataid"` + Key string `json:"key"` + Param struct { + DisableOpt bool `json:"disableOpt"` + Resid int64 `json:"resid"` + Crc string `json:"crc"` + Puid int `json:"puid"` + Isfile bool `json:"isfile"` + Pantype string `json:"pantype"` + Size int `json:"size"` + Name string `json:"name"` + ObjectID string `json:"objectId"` + Restype string `json:"restype"` + UploadDate int64 `json:"uploadDate"` + ModifyDate int64 `json:"modifyDate"` + UploadDateFormat string `json:"uploadDateFormat"` + Residstr string `json:"residstr"` + Suffix string `json:"suffix"` + Preview string `json:"preview"` + Thumbnail string `json:"thumbnail"` + Creator int `json:"creator"` + Duration int `json:"duration"` + IsImg bool `json:"isImg"` + PreviewURL string `json:"previewUrl"` + Filetype string `json:"filetype"` + Filepath string `json:"filepath"` + Sort int `json:"sort"` + Topsort int `json:"topsort"` + ResTypeValue int `json:"resTypeValue"` + Extinfo string `json:"extinfo"` + } `json:"param"` +} + +func fileToObj(f File) *model.Object { + if len(f.Content.FolderName) > 0 { + return &model.Object{ + ID: fmt.Sprintf("%d", f.ID), + Name: f.Content.FolderName, + Size: 0, + Modified: time.UnixMilli(f.Inserttime), + IsFolder: true, + } + } + paserTime := time.UnixMilli(f.Content.UploadDate) + return &model.Object{ + ID: fmt.Sprintf("%d$%s", f.ID, f.Content.FileID), + Name: f.Content.Name, + Size: int64(f.Content.Size), + Modified: paserTime, + IsFolder: false, + } +} diff --git a/drivers/chaoxing/util.go b/drivers/chaoxing/util.go new file mode 100644 index 0000000000000000000000000000000000000000..b6725804e6c1addfc51b48b12a889e6b1c7e382a --- /dev/null +++ b/drivers/chaoxing/util.go @@ -0,0 +1,183 @@ +package chaoxing + +import ( + "bytes" + "crypto/aes" + "crypto/cipher" + "encoding/base64" + "errors" + "fmt" + "mime/multipart" + "net/http" + "strings" + + "github.com/alist-org/alist/v3/drivers/base" + "github.com/go-resty/resty/v2" +) + +func (d *ChaoXing) requestDownload(pathname string, method string, callback base.ReqCallback, resp interface{}) ([]byte, error) { + u := d.conf.DowloadApi + pathname + req := base.RestyClient.R() + req.SetHeaders(map[string]string{ + "Cookie": d.Cookie, + "Accept": "application/json, text/plain, */*", + "Referer": d.conf.referer, + }) + if callback != nil { + callback(req) + } + if resp != nil { + req.SetResult(resp) + } + var e Resp + req.SetError(&e) + res, err := req.Execute(method, u) + if err != nil { + return nil, err + } + return res.Body(), nil +} + +func (d *ChaoXing) request(pathname string, method string, callback base.ReqCallback, resp interface{}) ([]byte, error) { + u := d.conf.api + pathname + if strings.Contains(pathname, "getUploadConfig") { + u = pathname + } + req := base.RestyClient.R() + req.SetHeaders(map[string]string{ + "Cookie": d.Cookie, + "Accept": "application/json, text/plain, */*", + "Referer": d.conf.referer, + }) + if callback != nil { + callback(req) + } + if resp != nil { + req.SetResult(resp) + } + var e Resp + req.SetError(&e) + res, err := req.Execute(method, u) + if err != nil { + return nil, err + } + return res.Body(), nil +} + +func (d *ChaoXing) GetFiles(parent string) ([]File, error) { + files := make([]File, 0) + query := map[string]string{ + "bbsid": d.Addition.Bbsid, + "folderId": parent, + "recType": "1", + } + var resp ListFileResp + _, err := d.request("/pc/resource/getResourceList", http.MethodGet, func(req *resty.Request) { + 
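+		// Hedged, inferred from the two paired calls in this function:
+		// recType=1 appears to list sub-folders and recType=2 the files,
+		// and the two result lists are merged below.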
req.SetQueryParams(query) + }, &resp) + if err != nil { + return nil, err + } + if resp.Result != 1 { + msg := fmt.Sprintf("error code is:%d", resp.Result) + return nil, errors.New(msg) + } + if len(resp.List) > 0 { + files = append(files, resp.List...) + } + querys := map[string]string{ + "bbsid": d.Addition.Bbsid, + "folderId": parent, + "recType": "2", + } + var resps ListFileResp + _, err = d.request("/pc/resource/getResourceList", http.MethodGet, func(req *resty.Request) { + req.SetQueryParams(querys) + }, &resps) + if err != nil { + return nil, err + } + for _, file := range resps.List { + // 手机端超星上传的文件没有fileID字段,但ObjectID与fileID相同,可代替 + if file.Content.FileID == "" { + file.Content.FileID = file.Content.ObjectID + } + files = append(files, file) + } + return files, nil +} + +func EncryptByAES(message, key string) (string, error) { + aesKey := []byte(key) + plainText := []byte(message) + block, err := aes.NewCipher(aesKey) + if err != nil { + return "", err + } + iv := aesKey[:aes.BlockSize] + mode := cipher.NewCBCEncrypter(block, iv) + padding := aes.BlockSize - len(plainText)%aes.BlockSize + paddedText := append(plainText, byte(padding)) + for i := 0; i < padding-1; i++ { + paddedText = append(paddedText, byte(padding)) + } + ciphertext := make([]byte, len(paddedText)) + mode.CryptBlocks(ciphertext, paddedText) + encrypted := base64.StdEncoding.EncodeToString(ciphertext) + return encrypted, nil +} + +func CookiesToString(cookies []*http.Cookie) string { + var cookieStr string + for _, cookie := range cookies { + cookieStr += cookie.Name + "=" + cookie.Value + "; " + } + if len(cookieStr) > 2 { + cookieStr = cookieStr[:len(cookieStr)-2] + } + return cookieStr +} + +func (d *ChaoXing) Login() (string, error) { + transferKey := "u2oh6Vu^HWe4_AES" + body := &bytes.Buffer{} + writer := multipart.NewWriter(body) + uname, err := EncryptByAES(d.Addition.UserName, transferKey) + if err != nil { + return "", err + } + password, err := EncryptByAES(d.Addition.Password, transferKey) + if err != nil { + return "", err + } + err = writer.WriteField("uname", uname) + if err != nil { + return "", err + } + err = writer.WriteField("password", password) + if err != nil { + return "", err + } + err = writer.WriteField("t", "true") + if err != nil { + return "", err + } + err = writer.Close() + if err != nil { + return "", err + } + // Create the request + req, err := http.NewRequest("POST", "https://passport2.chaoxing.com/fanyalogin", body) + if err != nil { + return "", err + } + req.Header.Set("Content-Type", writer.FormDataContentType()) + req.Header.Set("Content-Length", fmt.Sprintf("%d", body.Len())) + resp, err := http.DefaultClient.Do(req) + if err != nil { + return "", err + } + defer resp.Body.Close() + return CookiesToString(resp.Cookies()), nil + +} diff --git a/drivers/cloudreve/driver.go b/drivers/cloudreve/driver.go new file mode 100644 index 0000000000000000000000000000000000000000..8fc117aca2c0b0d0351f0e17c5389085d09af528 --- /dev/null +++ b/drivers/cloudreve/driver.go @@ -0,0 +1,224 @@ +package cloudreve + +import ( + "context" + "io" + "net/http" + "path" + "strconv" + "strings" + + "github.com/alist-org/alist/v3/drivers/base" + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/errs" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/go-resty/resty/v2" +) + +type Cloudreve struct { + model.Storage + Addition +} + +func (d *Cloudreve) Config() driver.Config { + return config +} + +func (d 
*Cloudreve) GetAddition() driver.Additional { + return &d.Addition +} + +func (d *Cloudreve) Init(ctx context.Context) error { + if d.Cookie != "" { + return nil + } + // removing trailing slash + d.Address = strings.TrimSuffix(d.Address, "/") + return d.login() +} + +func (d *Cloudreve) Drop(ctx context.Context) error { + d.Cookie = "" + return nil +} + +func (d *Cloudreve) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) { + var r DirectoryResp + err := d.request(http.MethodGet, "/directory"+dir.GetPath(), nil, &r) + if err != nil { + return nil, err + } + + return utils.SliceConvert(r.Objects, func(src Object) (model.Obj, error) { + thumb, err := d.GetThumb(src) + if err != nil { + return nil, err + } + if src.Type == "dir" && d.EnableThumbAndFolderSize { + var dprop DirectoryProp + err = d.request(http.MethodGet, "/object/property/"+src.Id+"?is_folder=true", nil, &dprop) + if err != nil { + return nil, err + } + src.Size = dprop.Size + } + return objectToObj(src, thumb), nil + }) +} + +func (d *Cloudreve) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) { + var dUrl string + err := d.request(http.MethodPut, "/file/download/"+file.GetID(), nil, &dUrl) + if err != nil { + return nil, err + } + if strings.HasPrefix(dUrl, "/api") { + dUrl = d.Address + dUrl + } + return &model.Link{ + URL: dUrl, + }, nil +} + +func (d *Cloudreve) MakeDir(ctx context.Context, parentDir model.Obj, dirName string) error { + return d.request(http.MethodPut, "/directory", func(req *resty.Request) { + req.SetBody(base.Json{ + "path": parentDir.GetPath() + "/" + dirName, + }) + }, nil) +} + +func (d *Cloudreve) Move(ctx context.Context, srcObj, dstDir model.Obj) error { + body := base.Json{ + "action": "move", + "src_dir": path.Dir(srcObj.GetPath()), + "dst": dstDir.GetPath(), + "src": convertSrc(srcObj), + } + return d.request(http.MethodPatch, "/object", func(req *resty.Request) { + req.SetBody(body) + }, nil) +} + +func (d *Cloudreve) Rename(ctx context.Context, srcObj model.Obj, newName string) error { + body := base.Json{ + "action": "rename", + "new_name": newName, + "src": convertSrc(srcObj), + } + return d.request(http.MethodPatch, "/object/rename", func(req *resty.Request) { + req.SetBody(body) + }, nil) +} + +func (d *Cloudreve) Copy(ctx context.Context, srcObj, dstDir model.Obj) error { + body := base.Json{ + "src_dir": path.Dir(srcObj.GetPath()), + "dst": dstDir.GetPath(), + "src": convertSrc(srcObj), + } + return d.request(http.MethodPost, "/object/copy", func(req *resty.Request) { + req.SetBody(body) + }, nil) +} + +func (d *Cloudreve) Remove(ctx context.Context, obj model.Obj) error { + body := convertSrc(obj) + err := d.request(http.MethodDelete, "/object", func(req *resty.Request) { + req.SetBody(body) + }, nil) + return err +} + +func (d *Cloudreve) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error { + if io.ReadCloser(stream) == http.NoBody { + return d.create(ctx, dstDir, stream) + } + + // 获取存储策略 + var r DirectoryResp + err := d.request(http.MethodGet, "/directory"+dstDir.GetPath(), nil, &r) + if err != nil { + return err + } + uploadBody := base.Json{ + "path": dstDir.GetPath(), + "size": stream.GetSize(), + "name": stream.GetName(), + "policy_id": r.Policy.Id, + "last_modified": stream.ModTime().Unix(), + } + + // 获取上传会话信息 + var u UploadInfo + err = d.request(http.MethodPut, "/file/upload", func(req *resty.Request) { + req.SetBody(uploadBody) + }, &u) + if err != nil { + 
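+		// No upload session exists yet at this point, so unlike a chunk
+		// failure below there is no /file/upload/{sessionID} cleanup to run.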
return err + } + + // 根据存储方式选择分片上传的方法 + switch r.Policy.Type { + case "onedrive": + err = d.upOneDrive(ctx, stream, u, up) + case "remote": // 从机存储 + err = d.upRemote(ctx, stream, u, up) + case "local": // 本机存储 + var chunkSize = u.ChunkSize + var buf []byte + var chunk int + for { + var n int + buf = make([]byte, chunkSize) + n, err = io.ReadAtLeast(stream, buf, chunkSize) + if err != nil && err != io.ErrUnexpectedEOF { + if err == io.EOF { + return nil + } + return err + } + if n == 0 { + break + } + buf = buf[:n] + err = d.request(http.MethodPost, "/file/upload/"+u.SessionID+"/"+strconv.Itoa(chunk), func(req *resty.Request) { + req.SetHeader("Content-Type", "application/octet-stream") + req.SetHeader("Content-Length", strconv.Itoa(n)) + req.SetBody(buf) + }, nil) + if err != nil { + break + } + chunk++ + } + default: + err = errs.NotImplement + } + if err != nil { + // 删除失败的会话 + err = d.request(http.MethodDelete, "/file/upload/"+u.SessionID, nil, nil) + return err + } + return nil +} + +func (d *Cloudreve) create(ctx context.Context, dir model.Obj, file model.Obj) error { + body := base.Json{"path": dir.GetPath() + "/" + file.GetName()} + if file.IsDir() { + err := d.request(http.MethodPut, "directory", func(req *resty.Request) { + req.SetBody(body) + }, nil) + return err + } + return d.request(http.MethodPost, "/file/create", func(req *resty.Request) { + req.SetBody(body) + }, nil) +} + +//func (d *Cloudreve) Other(ctx context.Context, args model.OtherArgs) (interface{}, error) { +// return nil, errs.NotSupport +//} + +var _ driver.Driver = (*Cloudreve)(nil) diff --git a/drivers/cloudreve/meta.go b/drivers/cloudreve/meta.go new file mode 100644 index 0000000000000000000000000000000000000000..92c0b9fb1d7ba24aa07a2a131d3d3b4974981c81 --- /dev/null +++ b/drivers/cloudreve/meta.go @@ -0,0 +1,29 @@ +package cloudreve + +import ( + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/op" +) + +type Addition struct { + // Usually one of two + driver.RootPath + // define other + Address string `json:"address" required:"true"` + Username string `json:"username"` + Password string `json:"password"` + Cookie string `json:"cookie"` + CustomUA string `json:"custom_ua"` + EnableThumbAndFolderSize bool `json:"enable_thumb_and_folder_size"` +} + +var config = driver.Config{ + Name: "Cloudreve", + DefaultRoot: "/", +} + +func init() { + op.RegisterDriver(func() driver.Driver { + return &Cloudreve{} + }) +} diff --git a/drivers/cloudreve/types.go b/drivers/cloudreve/types.go new file mode 100644 index 0000000000000000000000000000000000000000..a7c3919e8a97d51ca2e48c588da80904e3e11e8a --- /dev/null +++ b/drivers/cloudreve/types.go @@ -0,0 +1,69 @@ +package cloudreve + +import ( + "time" + + "github.com/alist-org/alist/v3/internal/model" +) + +type Resp struct { + Code int `json:"code"` + Msg string `json:"msg"` + Data interface{} `json:"data"` +} + +type Policy struct { + Id string `json:"id"` + Name string `json:"name"` + Type string `json:"type"` + MaxSize int `json:"max_size"` + FileType []string `json:"file_type"` +} + +type UploadInfo struct { + SessionID string `json:"sessionID"` + ChunkSize int `json:"chunkSize"` + Expires int `json:"expires"` + UploadURLs []string `json:"uploadURLs"` + Credential string `json:"credential,omitempty"` +} + +type DirectoryResp struct { + Parent string `json:"parent"` + Objects []Object `json:"objects"` + Policy Policy `json:"policy"` +} + +type Object struct { + Id string `json:"id"` + Name string `json:"name"` + Path string 
`json:"path"` + Pic string `json:"pic"` + Size int `json:"size"` + Type string `json:"type"` + Date time.Time `json:"date"` + CreateDate time.Time `json:"create_date"` + SourceEnabled bool `json:"source_enabled"` +} + +type DirectoryProp struct { + Size int `json:"size"` +} + +func objectToObj(f Object, t model.Thumbnail) *model.ObjThumb { + return &model.ObjThumb{ + Object: model.Object{ + ID: f.Id, + Name: f.Name, + Size: int64(f.Size), + Modified: f.Date, + IsFolder: f.Type == "dir", + }, + Thumbnail: t, + } +} + +type Config struct { + LoginCaptcha bool `json:"loginCaptcha"` + CaptchaType string `json:"captcha_type"` +} diff --git a/drivers/cloudreve/util.go b/drivers/cloudreve/util.go new file mode 100644 index 0000000000000000000000000000000000000000..b5b71153e129f96746bf54d944fec0c4737eaf3c --- /dev/null +++ b/drivers/cloudreve/util.go @@ -0,0 +1,273 @@ +package cloudreve + +import ( + "bytes" + "context" + "encoding/base64" + "errors" + "fmt" + "io" + "net/http" + "strconv" + "strings" + + "github.com/alist-org/alist/v3/drivers/base" + "github.com/alist-org/alist/v3/internal/conf" + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/internal/setting" + "github.com/alist-org/alist/v3/pkg/cookie" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/go-resty/resty/v2" + json "github.com/json-iterator/go" + jsoniter "github.com/json-iterator/go" +) + +// do others that not defined in Driver interface + +const loginPath = "/user/session" + +func (d *Cloudreve) request(method string, path string, callback base.ReqCallback, out interface{}) error { + u := d.Address + "/api/v3" + path + ua := d.CustomUA + if ua == "" { + ua = base.UserAgent + } + req := base.RestyClient.R() + req.SetHeaders(map[string]string{ + "Cookie": "cloudreve-session=" + d.Cookie, + "Accept": "application/json, text/plain, */*", + "User-Agent": ua, + }) + + var r Resp + req.SetResult(&r) + + if callback != nil { + callback(req) + } + + resp, err := req.Execute(method, u) + if err != nil { + return err + } + if !resp.IsSuccess() { + return errors.New(resp.String()) + } + + if r.Code != 0 { + + // 刷新 cookie + if r.Code == http.StatusUnauthorized && path != loginPath { + if d.Username != "" && d.Password != "" { + err = d.login() + if err != nil { + return err + } + return d.request(method, path, callback, out) + } + } + + return errors.New(r.Msg) + } + sess := cookie.GetCookie(resp.Cookies(), "cloudreve-session") + if sess != nil { + d.Cookie = sess.Value + } + if out != nil && r.Data != nil { + var marshal []byte + marshal, err = json.Marshal(r.Data) + if err != nil { + return err + } + err = json.Unmarshal(marshal, out) + if err != nil { + return err + } + } + + return nil +} + +func (d *Cloudreve) login() error { + var siteConfig Config + err := d.request(http.MethodGet, "/site/config", nil, &siteConfig) + if err != nil { + return err + } + for i := 0; i < 5; i++ { + err = d.doLogin(siteConfig.LoginCaptcha) + if err == nil { + break + } + if err != nil && err.Error() != "CAPTCHA not match." 
{ + break + } + } + return err +} + +func (d *Cloudreve) doLogin(needCaptcha bool) error { + var captchaCode string + var err error + if needCaptcha { + var captcha string + err = d.request(http.MethodGet, "/site/captcha", nil, &captcha) + if err != nil { + return err + } + if len(captcha) == 0 { + return errors.New("can not get captcha") + } + i := strings.Index(captcha, ",") + dec := base64.NewDecoder(base64.StdEncoding, strings.NewReader(captcha[i+1:])) + vRes, err := base.RestyClient.R().SetMultipartField( + "image", "validateCode.png", "image/png", dec). + Post(setting.GetStr(conf.OcrApi)) + if err != nil { + return err + } + if jsoniter.Get(vRes.Body(), "status").ToInt() != 200 { + return errors.New("ocr error:" + jsoniter.Get(vRes.Body(), "msg").ToString()) + } + captchaCode = jsoniter.Get(vRes.Body(), "result").ToString() + } + var resp Resp + err = d.request(http.MethodPost, loginPath, func(req *resty.Request) { + req.SetBody(base.Json{ + "username": d.Addition.Username, + "Password": d.Addition.Password, + "captchaCode": captchaCode, + }) + }, &resp) + return err +} + +func convertSrc(obj model.Obj) map[string]interface{} { + m := make(map[string]interface{}) + var dirs []string + var items []string + if obj.IsDir() { + dirs = append(dirs, obj.GetID()) + } else { + items = append(items, obj.GetID()) + } + m["dirs"] = dirs + m["items"] = items + return m +} + +func (d *Cloudreve) GetThumb(file Object) (model.Thumbnail, error) { + if !d.Addition.EnableThumbAndFolderSize { + return model.Thumbnail{}, nil + } + ua := d.CustomUA + if ua == "" { + ua = base.UserAgent + } + req := base.NoRedirectClient.R() + req.SetHeaders(map[string]string{ + "Cookie": "cloudreve-session=" + d.Cookie, + "Accept": "image/webp,image/apng,image/svg+xml,image/*,*/*;q=0.8", + "User-Agent": ua, + }) + resp, err := req.Execute(http.MethodGet, d.Address+"/api/v3/file/thumb/"+file.Id) + if err != nil { + return model.Thumbnail{}, err + } + return model.Thumbnail{ + Thumbnail: resp.Header().Get("Location"), + }, nil +} + +func (d *Cloudreve) upRemote(ctx context.Context, stream model.FileStreamer, u UploadInfo, up driver.UpdateProgress) error { + uploadUrl := u.UploadURLs[0] + credential := u.Credential + var finish int64 = 0 + var chunk int = 0 + DEFAULT := int64(u.ChunkSize) + for finish < stream.GetSize() { + if utils.IsCanceled(ctx) { + return ctx.Err() + } + utils.Log.Debugf("[Cloudreve-Remote] upload: %d", finish) + var byteSize = DEFAULT + left := stream.GetSize() - finish + if left < DEFAULT { + byteSize = left + } + byteData := make([]byte, byteSize) + n, err := io.ReadFull(stream, byteData) + utils.Log.Debug(err, n) + if err != nil { + return err + } + req, err := http.NewRequest("POST", uploadUrl+"?chunk="+strconv.Itoa(chunk), bytes.NewBuffer(byteData)) + if err != nil { + return err + } + req = req.WithContext(ctx) + req.Header.Set("Content-Length", strconv.Itoa(int(byteSize))) + req.Header.Set("Authorization", fmt.Sprint(credential)) + finish += byteSize + res, err := base.HttpClient.Do(req) + if err != nil { + return err + } + res.Body.Close() + up(float64(finish) * 100 / float64(stream.GetSize())) + chunk++ + } + return nil +} + +func (d *Cloudreve) upOneDrive(ctx context.Context, stream model.FileStreamer, u UploadInfo, up driver.UpdateProgress) error { + uploadUrl := u.UploadURLs[0] + var finish int64 = 0 + DEFAULT := int64(u.ChunkSize) + for finish < stream.GetSize() { + if utils.IsCanceled(ctx) { + return ctx.Err() + } + utils.Log.Debugf("[Cloudreve-OneDrive] upload: %d", finish) + var 
byteSize = DEFAULT
+        left := stream.GetSize() - finish
+        if left < DEFAULT {
+            byteSize = left
+        }
+        byteData := make([]byte, byteSize)
+        n, err := io.ReadFull(stream, byteData)
+        utils.Log.Debug(err, n)
+        if err != nil {
+            return err
+        }
+        req, err := http.NewRequest("PUT", uploadUrl, bytes.NewBuffer(byteData))
+        if err != nil {
+            return err
+        }
+        req = req.WithContext(ctx)
+        req.Header.Set("Content-Length", strconv.Itoa(int(byteSize)))
+        req.Header.Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", finish, finish+byteSize-1, stream.GetSize()))
+        finish += byteSize
+        res, err := base.HttpClient.Do(req)
+        if err != nil {
+            return err
+        }
+        // https://learn.microsoft.com/zh-cn/onedrive/developer/rest-api/api/driveitem_createuploadsession
+        if res.StatusCode != 201 && res.StatusCode != 202 && res.StatusCode != 200 {
+            data, _ := io.ReadAll(res.Body)
+            res.Body.Close()
+            return errors.New(string(data))
+        }
+        res.Body.Close()
+        up(float64(finish) * 100 / float64(stream.GetSize()))
+    }
+    // upload finished: notify Cloudreve via the OneDrive callback endpoint
+    err := d.request(http.MethodPost, "/callback/onedrive/finish/"+u.SessionID, func(req *resty.Request) {
+        req.SetBody("{}")
+    }, nil)
+    if err != nil {
+        return err
+    }
+    return nil
+}
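For reference, these are the Content-Range headers the loop above emits for a hypothetical 10 MiB stream uploaded in 4 MiB chunks; a quick way to sanity-check the finish/byteSize arithmetic, including the inclusive end offset (the sizes here are illustrative only):

package main

import "fmt"

func main() {
    size := int64(10 << 20) // 10 MiB stream
    chunk := int64(4 << 20) // 4 MiB chunks
    for finish := int64(0); finish < size; {
        byteSize := chunk
        if left := size - finish; left < chunk {
            byteSize = left
        }
        // same format string as upOneDrive: "bytes start-end/total", end inclusive
        fmt.Printf("Content-Range: bytes %d-%d/%d\n", finish, finish+byteSize-1, size)
        finish += byteSize
    }
}

The last range ends at total-1, which is exactly what the OneDrive upload-session API expects.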
"password2": p2, + "filename_encryption": d.FileNameEnc, + "directory_name_encryption": d.DirNameEnc, + "filename_encoding": d.FileNameEncoding, + "suffix": d.EncryptedSuffix, + "pass_bad_blocks": "", + } + c, err := rcCrypt.NewCipher(config) + if err != nil { + return fmt.Errorf("failed to create Cipher: %w", err) + } + d.cipher = c + + return nil +} + +func (d *Crypt) updateObfusParm(str *string) error { + temp := *str + if !strings.HasPrefix(temp, obfuscatedPrefix) { + temp, err := obscure.Obscure(temp) + if err != nil { + return err + } + temp = obfuscatedPrefix + temp + *str = temp + } + return nil +} + +func (d *Crypt) Drop(ctx context.Context) error { + return nil +} + +func (d *Crypt) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) { + path := dir.GetPath() + //return d.list(ctx, d.RemotePath, path) + //remoteFull + + objs, err := fs.List(ctx, d.getPathForRemote(path, true), &fs.ListArgs{NoLog: true}) + // the obj must implement the model.SetPath interface + // return objs, err + if err != nil { + return nil, err + } + + var result []model.Obj + for _, obj := range objs { + if obj.IsDir() { + name, err := d.cipher.DecryptDirName(obj.GetName()) + if err != nil { + //filter illegal files + continue + } + if !d.ShowHidden && strings.HasPrefix(name, ".") { + continue + } + objRes := model.Object{ + Name: name, + Size: 0, + Modified: obj.ModTime(), + IsFolder: obj.IsDir(), + Ctime: obj.CreateTime(), + // discarding hash as it's encrypted + } + result = append(result, &objRes) + } else { + thumb, ok := model.GetThumb(obj) + size, err := d.cipher.DecryptedSize(obj.GetSize()) + if err != nil { + //filter illegal files + continue + } + name, err := d.cipher.DecryptFileName(obj.GetName()) + if err != nil { + //filter illegal files + continue + } + if !d.ShowHidden && strings.HasPrefix(name, ".") { + continue + } + objRes := model.Object{ + Name: name, + Size: size, + Modified: obj.ModTime(), + IsFolder: obj.IsDir(), + Ctime: obj.CreateTime(), + // discarding hash as it's encrypted + } + if d.Thumbnail && thumb == "" { + thumbPath := stdpath.Join(args.ReqPath, ".thumbnails", name+".webp") + thumb = fmt.Sprintf("%s/d%s?sign=%s", + common.GetApiUrl(common.GetHttpReq(ctx)), + utils.EncodePath(thumbPath, true), + sign.Sign(thumbPath)) + } + if !ok && !d.Thumbnail { + result = append(result, &objRes) + } else { + objWithThumb := model.ObjThumb{ + Object: objRes, + Thumbnail: model.Thumbnail{ + Thumbnail: thumb, + }, + } + result = append(result, &objWithThumb) + } + } + } + + return result, nil +} + +func (d *Crypt) Get(ctx context.Context, path string) (model.Obj, error) { + if utils.PathEqual(path, "/") { + return &model.Object{ + Name: "Root", + IsFolder: true, + Path: "/", + }, nil + } + remoteFullPath := "" + var remoteObj model.Obj + var err, err2 error + firstTryIsFolder, secondTry := guessPath(path) + remoteFullPath = d.getPathForRemote(path, firstTryIsFolder) + remoteObj, err = fs.Get(ctx, remoteFullPath, &fs.GetArgs{NoLog: true}) + if err != nil { + if errs.IsObjectNotFound(err) && secondTry { + //try the opposite + remoteFullPath = d.getPathForRemote(path, !firstTryIsFolder) + remoteObj, err2 = fs.Get(ctx, remoteFullPath, &fs.GetArgs{NoLog: true}) + if err2 != nil { + return nil, err2 + } + } else { + return nil, err + } + } + var size int64 = 0 + name := "" + if !remoteObj.IsDir() { + size, err = d.cipher.DecryptedSize(remoteObj.GetSize()) + if err != nil { + log.Warnf("DecryptedSize failed for %s ,will use original size, err:%s", path, err) + size = 
+
+func (d *Crypt) Drop(ctx context.Context) error {
+    return nil
+}
+
+func (d *Crypt) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) {
+    path := dir.GetPath()
+    //return d.list(ctx, d.RemotePath, path)
+    //remoteFull
+
+    objs, err := fs.List(ctx, d.getPathForRemote(path, true), &fs.ListArgs{NoLog: true})
+    // the obj must implement the model.SetPath interface
+    // return objs, err
+    if err != nil {
+        return nil, err
+    }
+
+    var result []model.Obj
+    for _, obj := range objs {
+        if obj.IsDir() {
+            name, err := d.cipher.DecryptDirName(obj.GetName())
+            if err != nil {
+                // filter illegal files
+                continue
+            }
+            if !d.ShowHidden && strings.HasPrefix(name, ".") {
+                continue
+            }
+            objRes := model.Object{
+                Name:     name,
+                Size:     0,
+                Modified: obj.ModTime(),
+                IsFolder: obj.IsDir(),
+                Ctime:    obj.CreateTime(),
+                // discarding hash as it's encrypted
+            }
+            result = append(result, &objRes)
+        } else {
+            thumb, ok := model.GetThumb(obj)
+            size, err := d.cipher.DecryptedSize(obj.GetSize())
+            if err != nil {
+                // filter illegal files
+                continue
+            }
+            name, err := d.cipher.DecryptFileName(obj.GetName())
+            if err != nil {
+                // filter illegal files
+                continue
+            }
+            if !d.ShowHidden && strings.HasPrefix(name, ".") {
+                continue
+            }
+            objRes := model.Object{
+                Name:     name,
+                Size:     size,
+                Modified: obj.ModTime(),
+                IsFolder: obj.IsDir(),
+                Ctime:    obj.CreateTime(),
+                // discarding hash as it's encrypted
+            }
+            if d.Thumbnail && thumb == "" {
+                thumbPath := stdpath.Join(args.ReqPath, ".thumbnails", name+".webp")
+                thumb = fmt.Sprintf("%s/d%s?sign=%s",
+                    common.GetApiUrl(common.GetHttpReq(ctx)),
+                    utils.EncodePath(thumbPath, true),
+                    sign.Sign(thumbPath))
+            }
+            if !ok && !d.Thumbnail {
+                result = append(result, &objRes)
+            } else {
+                objWithThumb := model.ObjThumb{
+                    Object: objRes,
+                    Thumbnail: model.Thumbnail{
+                        Thumbnail: thumb,
+                    },
+                }
+                result = append(result, &objWithThumb)
+            }
+        }
+    }
+
+    return result, nil
+}
+
+func (d *Crypt) Get(ctx context.Context, path string) (model.Obj, error) {
+    if utils.PathEqual(path, "/") {
+        return &model.Object{
+            Name:     "Root",
+            IsFolder: true,
+            Path:     "/",
+        }, nil
+    }
+    remoteFullPath := ""
+    var remoteObj model.Obj
+    var err, err2 error
+    firstTryIsFolder, secondTry := guessPath(path)
+    remoteFullPath = d.getPathForRemote(path, firstTryIsFolder)
+    remoteObj, err = fs.Get(ctx, remoteFullPath, &fs.GetArgs{NoLog: true})
+    if err != nil {
+        if errs.IsObjectNotFound(err) && secondTry {
+            // try the opposite guess
+            remoteFullPath = d.getPathForRemote(path, !firstTryIsFolder)
+            remoteObj, err2 = fs.Get(ctx, remoteFullPath, &fs.GetArgs{NoLog: true})
+            if err2 != nil {
+                return nil, err2
+            }
+        } else {
+            return nil, err
+        }
+    }
+    var size int64 = 0
+    name := ""
+    if !remoteObj.IsDir() {
+        size, err = d.cipher.DecryptedSize(remoteObj.GetSize())
+        if err != nil {
+            log.Warnf("DecryptedSize failed for %s, will use original size, err: %s", path, err)
+            size = remoteObj.GetSize()
+        }
+        name, err = d.cipher.DecryptFileName(remoteObj.GetName())
+        if err != nil {
+            log.Warnf("DecryptFileName failed for %s, will use original name, err: %s", path, err)
+            name = remoteObj.GetName()
+        }
+    } else {
+        name, err = d.cipher.DecryptDirName(remoteObj.GetName())
+        if err != nil {
+            log.Warnf("DecryptDirName failed for %s, will use original name, err: %s", path, err)
+            name = remoteObj.GetName()
+        }
+    }
+    obj := &model.Object{
+        Path:     path,
+        Name:     name,
+        Size:     size,
+        Modified: remoteObj.ModTime(),
+        IsFolder: remoteObj.IsDir(),
+    }
+    return obj, nil
+    //return nil, errs.ObjectNotFound
+}
+
+func (d *Crypt) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) {
+    dstDirActualPath, err := d.getActualPathForRemote(file.GetPath(), false)
+    if err != nil {
+        return nil, fmt.Errorf("failed to convert path to remote path: %w", err)
+    }
+    remoteLink, remoteFile, err := op.Link(ctx, d.remoteStorage, dstDirActualPath, args)
+    if err != nil {
+        return nil, err
+    }
+
+    if remoteLink.RangeReadCloser == nil && remoteLink.MFile == nil && len(remoteLink.URL) == 0 {
+        return nil, fmt.Errorf("the remote storage driver needs to be enhanced to support encryption")
+    }
+    remoteFileSize := remoteFile.GetSize()
+    remoteClosers := utils.EmptyClosers()
+    rangeReaderFunc := func(ctx context.Context, underlyingOffset, underlyingLength int64) (io.ReadCloser, error) {
+        length := underlyingLength
+        if underlyingLength >= 0 && underlyingOffset+underlyingLength >= remoteFileSize {
+            length = -1
+        }
+        rrc := remoteLink.RangeReadCloser
+        if len(remoteLink.URL) > 0 {
+            rangedRemoteLink := &model.Link{
+                URL:    remoteLink.URL,
+                Header: remoteLink.Header,
+            }
+            var converted, err = stream.GetRangeReadCloserFromLink(remoteFileSize, rangedRemoteLink)
+            if err != nil {
+                return nil, err
+            }
+            rrc = converted
+        }
+        if rrc != nil {
+            //remoteRangeReader, err :=
+            remoteReader, err := rrc.RangeRead(ctx, http_range.Range{Start: underlyingOffset, Length: length})
+            remoteClosers.AddClosers(rrc.GetClosers())
+            if err != nil {
+                return nil, err
+            }
+            return remoteReader, nil
+        }
+        if remoteLink.MFile != nil {
+            _, err := remoteLink.MFile.Seek(underlyingOffset, io.SeekStart)
+            if err != nil {
+                return nil, err
+            }
+            //remoteClosers.Add(remoteLink.MFile)
+            // keep reusing the same MFile and close it at the end
+ remoteClosers.Add(remoteLink.MFile) + return io.NopCloser(remoteLink.MFile), nil + } + + return nil, errs.NotSupport + + } + resultRangeReader := func(ctx context.Context, httpRange http_range.Range) (io.ReadCloser, error) { + readSeeker, err := d.cipher.DecryptDataSeek(ctx, rangeReaderFunc, httpRange.Start, httpRange.Length) + if err != nil { + return nil, err + } + return readSeeker, nil + } + + resultRangeReadCloser := &model.RangeReadCloser{RangeReader: resultRangeReader, Closers: remoteClosers} + resultLink := &model.Link{ + Header: remoteLink.Header, + RangeReadCloser: resultRangeReadCloser, + Expiration: remoteLink.Expiration, + } + + return resultLink, nil + +} + +func (d *Crypt) MakeDir(ctx context.Context, parentDir model.Obj, dirName string) error { + dstDirActualPath, err := d.getActualPathForRemote(parentDir.GetPath(), true) + if err != nil { + return fmt.Errorf("failed to convert path to remote path: %w", err) + } + dir := d.cipher.EncryptDirName(dirName) + return op.MakeDir(ctx, d.remoteStorage, stdpath.Join(dstDirActualPath, dir)) +} + +func (d *Crypt) Move(ctx context.Context, srcObj, dstDir model.Obj) error { + srcRemoteActualPath, err := d.getActualPathForRemote(srcObj.GetPath(), srcObj.IsDir()) + if err != nil { + return fmt.Errorf("failed to convert path to remote path: %w", err) + } + dstRemoteActualPath, err := d.getActualPathForRemote(dstDir.GetPath(), dstDir.IsDir()) + if err != nil { + return fmt.Errorf("failed to convert path to remote path: %w", err) + } + return op.Move(ctx, d.remoteStorage, srcRemoteActualPath, dstRemoteActualPath) +} + +func (d *Crypt) Rename(ctx context.Context, srcObj model.Obj, newName string) error { + remoteActualPath, err := d.getActualPathForRemote(srcObj.GetPath(), srcObj.IsDir()) + if err != nil { + return fmt.Errorf("failed to convert path to remote path: %w", err) + } + var newEncryptedName string + if srcObj.IsDir() { + newEncryptedName = d.cipher.EncryptDirName(newName) + } else { + newEncryptedName = d.cipher.EncryptFileName(newName) + } + return op.Rename(ctx, d.remoteStorage, remoteActualPath, newEncryptedName) +} + +func (d *Crypt) Copy(ctx context.Context, srcObj, dstDir model.Obj) error { + srcRemoteActualPath, err := d.getActualPathForRemote(srcObj.GetPath(), srcObj.IsDir()) + if err != nil { + return fmt.Errorf("failed to convert path to remote path: %w", err) + } + dstRemoteActualPath, err := d.getActualPathForRemote(dstDir.GetPath(), dstDir.IsDir()) + if err != nil { + return fmt.Errorf("failed to convert path to remote path: %w", err) + } + return op.Copy(ctx, d.remoteStorage, srcRemoteActualPath, dstRemoteActualPath) + +} + +func (d *Crypt) Remove(ctx context.Context, obj model.Obj) error { + remoteActualPath, err := d.getActualPathForRemote(obj.GetPath(), obj.IsDir()) + if err != nil { + return fmt.Errorf("failed to convert path to remote path: %w", err) + } + return op.Remove(ctx, d.remoteStorage, remoteActualPath) +} + +func (d *Crypt) Put(ctx context.Context, dstDir model.Obj, streamer model.FileStreamer, up driver.UpdateProgress) error { + dstDirActualPath, err := d.getActualPathForRemote(dstDir.GetPath(), true) + if err != nil { + return fmt.Errorf("failed to convert path to remote path: %w", err) + } + + // Encrypt the data into wrappedIn + wrappedIn, err := d.cipher.EncryptData(streamer) + if err != nil { + return fmt.Errorf("failed to EncryptData: %w", err) + } + + // doesn't support seekableStream, since rapid-upload is not working for encrypted data + streamOut := &stream.FileStream{ + Obj: &model.Object{ 
+            ID:       streamer.GetID(),
+            Path:     streamer.GetPath(),
+            Name:     d.cipher.EncryptFileName(streamer.GetName()),
+            Size:     d.cipher.EncryptedSize(streamer.GetSize()),
+            Modified: streamer.ModTime(),
+            IsFolder: streamer.IsDir(),
+        },
+        Reader:            wrappedIn,
+        Mimetype:          "application/octet-stream",
+        WebPutAsTask:      streamer.NeedStore(),
+        ForceStreamUpload: true,
+        Exist:             streamer.GetExist(),
+    }
+    err = op.Put(ctx, d.remoteStorage, dstDirActualPath, streamOut, up, false)
+    if err != nil {
+        return err
+    }
+    return nil
+}
+
+//func (d *Safe) Other(ctx context.Context, args model.OtherArgs) (interface{}, error) {
+//    return nil, errs.NotSupport
+//}
+
+var _ driver.Driver = (*Crypt)(nil)
diff --git a/drivers/crypt/meta.go b/drivers/crypt/meta.go
new file mode 100644
index 0000000000000000000000000000000000000000..180773a3f48329095a59b7f11dd047b4ef4c2707
--- /dev/null
+++ b/drivers/crypt/meta.go
@@ -0,0 +1,46 @@
+package crypt
+
+import (
+    "github.com/alist-org/alist/v3/internal/driver"
+    "github.com/alist-org/alist/v3/internal/op"
+)
+
+type Addition struct {
+    // Usually one of two
+    //driver.RootPath
+    //driver.RootID
+    // define other
+
+    FileNameEnc string `json:"filename_encryption" type:"select" required:"true" options:"off,standard,obfuscate" default:"off"`
+    DirNameEnc  string `json:"directory_name_encryption" type:"select" required:"true" options:"false,true" default:"false"`
+    RemotePath  string `json:"remote_path" required:"true" help:"This is where the encrypted data is stored"`
+
+    Password        string `json:"password" required:"true" confidential:"true" help:"the main password"`
+    Salt            string `json:"salt" confidential:"true" help:"If you don't know what salt is, treat it as a second password. Optional but recommended"`
+    EncryptedSuffix string `json:"encrypted_suffix" required:"true" default:".bin" help:"for advanced user only!
encrypted files will have this suffix"` + FileNameEncoding string `json:"filename_encoding" type:"select" required:"true" options:"base64,base32,base32768" default:"base64" help:"for advanced user only!"` + + Thumbnail bool `json:"thumbnail" required:"true" default:"false" help:"enable thumbnail which pre-generated under .thumbnails folder"` + + ShowHidden bool `json:"show_hidden" default:"true" required:"false" help:"show hidden directories and files"` +} + +var config = driver.Config{ + Name: "Crypt", + LocalSort: true, + OnlyLocal: false, + OnlyProxy: true, + NoCache: true, + NoUpload: false, + NeedMs: false, + DefaultRoot: "/", + CheckStatus: false, + Alert: "", + NoOverwriteUpload: false, +} + +func init() { + op.RegisterDriver(func() driver.Driver { + return &Crypt{} + }) +} diff --git a/drivers/crypt/types.go b/drivers/crypt/types.go new file mode 100644 index 0000000000000000000000000000000000000000..283fd7b6ab5b08db7ae897473143ae1a7a6bbe68 --- /dev/null +++ b/drivers/crypt/types.go @@ -0,0 +1 @@ +package crypt diff --git a/drivers/crypt/util.go b/drivers/crypt/util.go new file mode 100644 index 0000000000000000000000000000000000000000..3e55fb37ac1c6729e3e3b9a47c9e1a9f3782d251 --- /dev/null +++ b/drivers/crypt/util.go @@ -0,0 +1,44 @@ +package crypt + +import ( + stdpath "path" + "path/filepath" + "strings" + + "github.com/alist-org/alist/v3/internal/op" +) + +// will give the best guessing based on the path +func guessPath(path string) (isFolder, secondTry bool) { + if strings.HasSuffix(path, "/") { + //confirmed a folder + return true, false + } + lastSlash := strings.LastIndex(path, "/") + if strings.Index(path[lastSlash:], ".") < 0 { + //no dot, try folder then try file + return true, true + } + return false, true +} + +func (d *Crypt) getPathForRemote(path string, isFolder bool) (remoteFullPath string) { + if isFolder && !strings.HasSuffix(path, "/") { + path = path + "/" + } + dir, fileName := filepath.Split(path) + + remoteDir := d.cipher.EncryptDirName(dir) + remoteFileName := "" + if len(strings.TrimSpace(fileName)) > 0 { + remoteFileName = d.cipher.EncryptFileName(fileName) + } + return stdpath.Join(d.RemotePath, remoteDir, remoteFileName) + +} + +// actual path is used for internal only. 
any link for user should come from remoteFullPath +func (d *Crypt) getActualPathForRemote(path string, isFolder bool) (string, error) { + _, remoteActualPath, err := op.GetStorageAndActualPath(d.getPathForRemote(path, isFolder)) + return remoteActualPath, err +} diff --git a/drivers/dropbox/driver.go b/drivers/dropbox/driver.go new file mode 100644 index 0000000000000000000000000000000000000000..9b1717b04d9aaeb3339961f8e095e6078ec6fa2a --- /dev/null +++ b/drivers/dropbox/driver.go @@ -0,0 +1,240 @@ +package dropbox + +import ( + "context" + "fmt" + "io" + "math" + "net/http" + "time" + + "github.com/alist-org/alist/v3/drivers/base" + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/go-resty/resty/v2" + log "github.com/sirupsen/logrus" +) + +type Dropbox struct { + model.Storage + Addition + base string + contentBase string +} + +func (d *Dropbox) Config() driver.Config { + return config +} + +func (d *Dropbox) GetAddition() driver.Additional { + return &d.Addition +} + +func (d *Dropbox) Init(ctx context.Context) error { + query := "foo" + res, err := d.request("/2/check/user", http.MethodPost, func(req *resty.Request) { + req.SetBody(base.Json{ + "query": query, + }) + }) + if err != nil { + return err + } + result := utils.Json.Get(res, "result").ToString() + if result != query { + return fmt.Errorf("failed to check user: %s", string(res)) + } + d.RootNamespaceId, err = d.GetRootNamespaceId(ctx) + + return err +} + +func (d *Dropbox) GetRootNamespaceId(ctx context.Context) (string, error) { + res, err := d.request("/2/users/get_current_account", http.MethodPost, func(req *resty.Request) { + req.SetBody(nil) + }) + if err != nil { + return "", err + } + var currentAccountResp CurrentAccountResp + err = utils.Json.Unmarshal(res, ¤tAccountResp) + if err != nil { + return "", err + } + rootNamespaceId := currentAccountResp.RootInfo.RootNamespaceId + return rootNamespaceId, nil +} + +func (d *Dropbox) Drop(ctx context.Context) error { + return nil +} + +func (d *Dropbox) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) { + files, err := d.getFiles(ctx, dir.GetPath()) + if err != nil { + return nil, err + } + return utils.SliceConvert(files, func(src File) (model.Obj, error) { + return fileToObj(src), nil + }) +} + +func (d *Dropbox) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) { + res, err := d.request("/2/files/get_temporary_link", http.MethodPost, func(req *resty.Request) { + req.SetContext(ctx).SetBody(base.Json{ + "path": file.GetPath(), + }) + }) + if err != nil { + return nil, err + } + url := utils.Json.Get(res, "link").ToString() + exp := time.Hour + return &model.Link{ + URL: url, + Expiration: &exp, + }, nil +} + +func (d *Dropbox) MakeDir(ctx context.Context, parentDir model.Obj, dirName string) error { + _, err := d.request("/2/files/create_folder_v2", http.MethodPost, func(req *resty.Request) { + req.SetContext(ctx).SetBody(base.Json{ + "autorename": false, + "path": parentDir.GetPath() + "/" + dirName, + }) + }) + return err +} + +func (d *Dropbox) Move(ctx context.Context, srcObj, dstDir model.Obj) error { + toPath := dstDir.GetPath() + "/" + srcObj.GetName() + + _, err := d.request("/2/files/move_v2", http.MethodPost, func(req *resty.Request) { + req.SetContext(ctx).SetBody(base.Json{ + "allow_ownership_transfer": false, + "allow_shared_folder": false, + "autorename": false, + "from_path": 
srcObj.GetID(),
+            "to_path":   toPath,
+        })
+    })
+    return err
+}
+
+func (d *Dropbox) Rename(ctx context.Context, srcObj model.Obj, newName string) error {
+    path := srcObj.GetPath()
+    fileName := srcObj.GetName()
+    toPath := path[:len(path)-len(fileName)] + newName
+
+    _, err := d.request("/2/files/move_v2", http.MethodPost, func(req *resty.Request) {
+        req.SetContext(ctx).SetBody(base.Json{
+            "allow_ownership_transfer": false,
+            "allow_shared_folder":      false,
+            "autorename":               false,
+            "from_path":                srcObj.GetID(),
+            "to_path":                  toPath,
+        })
+    })
+    return err
+}
+
+func (d *Dropbox) Copy(ctx context.Context, srcObj, dstDir model.Obj) error {
+    toPath := dstDir.GetPath() + "/" + srcObj.GetName()
+    _, err := d.request("/2/files/copy_v2", http.MethodPost, func(req *resty.Request) {
+        req.SetContext(ctx).SetBody(base.Json{
+            "allow_ownership_transfer": false,
+            "allow_shared_folder":      false,
+            "autorename":               false,
+            "from_path":                srcObj.GetID(),
+            "to_path":                  toPath,
+        })
+    })
+    return err
+}
+
+func (d *Dropbox) Remove(ctx context.Context, obj model.Obj) error {
+    uri := "/2/files/delete_v2"
+    _, err := d.request(uri, http.MethodPost, func(req *resty.Request) {
+        req.SetContext(ctx).SetBody(base.Json{
+            "path": obj.GetID(),
+        })
+    })
+    return err
+}
+
+func (d *Dropbox) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error {
+    // 1. start an upload session
+    sessionId, err := d.startUploadSession(ctx)
+    if err != nil {
+        return err
+    }
+
+    // 2. append chunks
+    // A single request should not upload more than 150 MB, and each call must be a multiple of 4 MB (except for the last call)
+    const PartSize = 20971520
+    count := 1
+    if stream.GetSize() > PartSize {
+        count = int(math.Ceil(float64(stream.GetSize()) / float64(PartSize)))
+    }
+    offset := int64(0)
+
+    for i := 0; i < count; i++ {
+        if utils.IsCanceled(ctx) {
+            return ctx.Err()
+        }
+
+        start := i * PartSize
+        byteSize := stream.GetSize() - int64(start)
+        if byteSize > PartSize {
+            byteSize = PartSize
+        }
+
+        url := d.contentBase + "/2/files/upload_session/append_v2"
+        reader := io.LimitReader(stream, PartSize)
+        req, err := http.NewRequest(http.MethodPost, url, reader)
+        if err != nil {
+            log.Errorf("failed to update file when append to upload session, err: %+v", err)
+            return err
+        }
+        req = req.WithContext(ctx)
+        req.Header.Set("Content-Type", "application/octet-stream")
+        req.Header.Set("Authorization", "Bearer "+d.AccessToken)
+
+        args := UploadAppendArgs{
+            Close: false,
+            Cursor: UploadCursor{
+                Offset:    offset,
+                SessionID: sessionId,
+            },
+        }
+        argsJson, err := utils.Json.MarshalToString(args)
+        if err != nil {
+            return err
+        }
+        req.Header.Set("Dropbox-API-Arg", argsJson)
+
+        res, err := base.HttpClient.Do(req)
+        if err != nil {
+            return err
+        }
+        _ = res.Body.Close()
+
+        up(float64(i+1) * 100 / float64(count))
+
+        offset += byteSize
+    }
+
+    // 3. finish the session and commit the file
+    toPath := dstDir.GetPath() + "/" + stream.GetName()
+    err2 := d.finishUploadSession(ctx, toPath, offset, sessionId)
+    if err2 != nil {
+        return err2
+    }
+
+    return err
+}
+
+var _ driver.Driver = (*Dropbox)(nil)
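The append and finish calls above pass their parameters out-of-band in the Dropbox-API-Arg header, because the request body carries the raw chunk bytes. A self-contained sketch of the JSON that header ends up holding, reusing the same cursor types as the diff (the offset and session ID are hypothetical):

package main

import (
    "encoding/json"
    "fmt"
)

// these mirror the UploadCursor/UploadAppendArgs types defined in the diff
type UploadCursor struct {
    Offset    int64  `json:"offset"`
    SessionID string `json:"session_id"`
}

type UploadAppendArgs struct {
    Close  bool         `json:"close"`
    Cursor UploadCursor `json:"cursor"`
}

func main() {
    args := UploadAppendArgs{
        Cursor: UploadCursor{Offset: 20971520, SessionID: "sess-123"}, // hypothetical values
    }
    b, _ := json.Marshal(args)
    // the driver sends this string in the Dropbox-API-Arg request header
    fmt.Println(string(b)) // {"close":false,"cursor":{"offset":20971520,"session_id":"sess-123"}}
}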
diff --git a/drivers/dropbox/meta.go b/drivers/dropbox/meta.go
new file mode 100644
index 0000000000000000000000000000000000000000..6e7bc014790fa99a5f6a5186ddb796a9c29c5bcf
--- /dev/null
+++ b/drivers/dropbox/meta.go
@@ -0,0 +1,43 @@
+package dropbox
+
+import (
+    "github.com/alist-org/alist/v3/internal/driver"
+    "github.com/alist-org/alist/v3/internal/op"
+)
+
+const (
+    DefaultClientID = "76lrwrklhdn1icb"
+)
+
+type Addition struct {
+    RefreshToken string `json:"refresh_token" required:"true"`
+    driver.RootPath
+
+    OauthTokenURL string `json:"oauth_token_url" default:"https://api.xhofe.top/alist/dropbox/token"`
+    ClientID      string `json:"client_id" required:"false" help:"Keep it empty if you don't have one"`
+    ClientSecret  string `json:"client_secret" required:"false" help:"Keep it empty if you don't have one"`
+
+    AccessToken     string
+    RootNamespaceId string
+}
+
+var config = driver.Config{
+    Name:              "Dropbox",
+    LocalSort:         false,
+    OnlyLocal:         false,
+    OnlyProxy:         false,
+    NoCache:           false,
+    NoUpload:          false,
+    NeedMs:            false,
+    DefaultRoot:       "",
+    NoOverwriteUpload: true,
+}
+
+func init() {
+    op.RegisterDriver(func() driver.Driver {
+        return &Dropbox{
+            base:        "https://api.dropboxapi.com",
+            contentBase: "https://content.dropboxapi.com",
+        }
+    })
+}
diff --git a/drivers/dropbox/types.go b/drivers/dropbox/types.go
new file mode 100644
index 0000000000000000000000000000000000000000..f2ec4cb7e7d8538687e785828224ef0ae9607085
--- /dev/null
+++ b/drivers/dropbox/types.go
@@ -0,0 +1,86 @@
+package dropbox
+
+import (
+    "time"
+
+    "github.com/alist-org/alist/v3/internal/model"
+)
+
+type TokenResp struct {
+    AccessToken string `json:"access_token"`
+    TokenType   string `json:"token_type"`
+    ExpiresIn   int    `json:"expires_in"`
+}
+
+type ErrorResp struct {
+    Error struct {
+        Tag string `json:".tag"`
+    } `json:"error"`
+    ErrorSummary string `json:"error_summary"`
+}
+
+type RefreshTokenErrorResp struct {
+    Error            string `json:"error"`
+    ErrorDescription string `json:"error_description"`
+}
+
+type CurrentAccountResp struct {
+    RootInfo struct {
+        RootNamespaceId string `json:"root_namespace_id"`
+        HomeNamespaceId string `json:"home_namespace_id"`
+    } `json:"root_info"`
+}
+
+type File struct {
+    Tag            string    `json:".tag"`
+    Name           string    `json:"name"`
+    PathLower      string    `json:"path_lower"`
+    PathDisplay    string    `json:"path_display"`
+    ID             string    `json:"id"`
+    ClientModified time.Time `json:"client_modified"`
+    ServerModified time.Time `json:"server_modified"`
+    Rev            string    `json:"rev"`
+    Size           int       `json:"size"`
+    IsDownloadable bool      `json:"is_downloadable"`
+    ContentHash    string    `json:"content_hash"`
+}
+
+type ListResp struct {
+    Entries []File `json:"entries"`
+    Cursor  string `json:"cursor"`
+    HasMore bool   `json:"has_more"`
+}
+
+type UploadCursor struct {
+    Offset    int64  `json:"offset"`
+    SessionID string `json:"session_id"`
+}
+
+type UploadAppendArgs struct {
+    Close  bool         `json:"close"`
+    Cursor UploadCursor `json:"cursor"`
+}
+
+type UploadFinishArgs struct {
+    Commit struct {
+        Autorename     bool   `json:"autorename"`
+        Mode           string `json:"mode"`
+        Mute           bool   `json:"mute"`
+        Path           string `json:"path"`
+        StrictConflict bool   `json:"strict_conflict"`
+    } `json:"commit"`
+    Cursor UploadCursor `json:"cursor"`
+}
+
+func fileToObj(f File) *model.ObjThumb {
+    return &model.ObjThumb{
+        Object: model.Object{
+            ID:       f.ID,
+            Path:     f.PathDisplay,
+            Name:     f.Name,
+            Size:     int64(f.Size),
+            Modified: f.ServerModified,
+            IsFolder: f.Tag == "folder",
+        },
+        Thumbnail: model.Thumbnail{},
+    }
+}
diff --git a/drivers/dropbox/util.go b/drivers/dropbox/util.go
new file mode 100644
index 0000000000000000000000000000000000000000..5065f08d394926afdcc1d7fdc46228641b858c9a
--- /dev/null
+++ b/drivers/dropbox/util.go
@@ -0,0 +1,209 @@
+package dropbox
+
+import (
+    "context"
+    "fmt"
+    "io"
+    "net/http"
+    "strings"
+
+    "github.com/alist-org/alist/v3/drivers/base"
+    "github.com/alist-org/alist/v3/internal/op"
+    "github.com/alist-org/alist/v3/pkg/utils"
+    "github.com/go-resty/resty/v2"
log "github.com/sirupsen/logrus" +) + +func (d *Dropbox) refreshToken() error { + url := d.base + "/oauth2/token" + if utils.SliceContains([]string{"", DefaultClientID}, d.ClientID) { + url = d.OauthTokenURL + } + var tokenResp TokenResp + resp, err := base.RestyClient.R(). + //ForceContentType("application/x-www-form-urlencoded"). + //SetBasicAuth(d.ClientID, d.ClientSecret). + SetFormData(map[string]string{ + "grant_type": "refresh_token", + "refresh_token": d.RefreshToken, + "client_id": d.ClientID, + "client_secret": d.ClientSecret, + }). + Post(url) + if err != nil { + return err + } + log.Debugf("[dropbox] refresh token response: %s", resp.String()) + if resp.StatusCode() != 200 { + return fmt.Errorf("failed to refresh token: %s", resp.String()) + } + _ = utils.Json.UnmarshalFromString(resp.String(), &tokenResp) + d.AccessToken = tokenResp.AccessToken + op.MustSaveDriverStorage(d) + return nil +} + +func (d *Dropbox) request(uri, method string, callback base.ReqCallback, retry ...bool) ([]byte, error) { + req := base.RestyClient.R() + req.SetHeader("Authorization", "Bearer "+d.AccessToken) + if d.RootNamespaceId != "" { + apiPathRootJson, err := utils.Json.MarshalToString(map[string]interface{}{ + ".tag": "root", + "root": d.RootNamespaceId, + }) + if err != nil { + return nil, err + } + req.SetHeader("Dropbox-API-Path-Root", apiPathRootJson) + } + if callback != nil { + callback(req) + } + if method == http.MethodPost && req.Body != nil { + req.SetHeader("Content-Type", "application/json") + } + var e ErrorResp + req.SetError(&e) + res, err := req.Execute(method, d.base+uri) + if err != nil { + return nil, err + } + log.Debugf("[dropbox] request (%s) response: %s", uri, res.String()) + isRetry := len(retry) > 0 && retry[0] + if res.StatusCode() != 200 { + body := res.String() + if !isRetry && (utils.SliceMeet([]string{"expired_access_token", "invalid_access_token", "authorization"}, body, + func(item string, v string) bool { + return strings.Contains(v, item) + }) || d.AccessToken == "") { + err = d.refreshToken() + if err != nil { + return nil, err + } + return d.request(uri, method, callback, true) + } + return nil, fmt.Errorf("%s:%s", e.Error, e.ErrorSummary) + } + return res.Body(), nil +} + +func (d *Dropbox) list(ctx context.Context, data base.Json, isContinue bool) (*ListResp, error) { + var resp ListResp + uri := "/2/files/list_folder" + if isContinue { + uri += "/continue" + } + _, err := d.request(uri, http.MethodPost, func(req *resty.Request) { + req.SetContext(ctx).SetBody(data).SetResult(&resp) + }) + if err != nil { + return nil, err + } + return &resp, nil +} + +func (d *Dropbox) getFiles(ctx context.Context, path string) ([]File, error) { + hasMore := true + var marker string + res := make([]File, 0) + + data := base.Json{ + "include_deleted": false, + "include_has_explicit_shared_members": false, + "include_mounted_folders": false, + "include_non_downloadable_files": false, + "limit": 2000, + "path": path, + "recursive": false, + } + resp, err := d.list(ctx, data, false) + if err != nil { + return nil, err + } + marker = resp.Cursor + hasMore = resp.HasMore + res = append(res, resp.Entries...) + + for hasMore { + data := base.Json{ + "cursor": marker, + } + resp, err := d.list(ctx, data, true) + if err != nil { + return nil, err + } + marker = resp.Cursor + hasMore = resp.HasMore + res = append(res, resp.Entries...) 
+ } + return res, nil +} + +func (d *Dropbox) finishUploadSession(ctx context.Context, toPath string, offset int64, sessionId string) error { + url := d.contentBase + "/2/files/upload_session/finish" + req, err := http.NewRequest(http.MethodPost, url, nil) + if err != nil { + return err + } + req = req.WithContext(ctx) + req.Header.Set("Content-Type", "application/octet-stream") + req.Header.Set("Authorization", "Bearer "+d.AccessToken) + + uploadFinishArgs := UploadFinishArgs{ + Commit: struct { + Autorename bool `json:"autorename"` + Mode string `json:"mode"` + Mute bool `json:"mute"` + Path string `json:"path"` + StrictConflict bool `json:"strict_conflict"` + }{ + Autorename: true, + Mode: "add", + Mute: false, + Path: toPath, + StrictConflict: false, + }, + Cursor: UploadCursor{ + Offset: offset, + SessionID: sessionId, + }, + } + + argsJson, err := utils.Json.MarshalToString(uploadFinishArgs) + if err != nil { + return err + } + req.Header.Set("Dropbox-API-Arg", argsJson) + + res, err := base.HttpClient.Do(req) + if err != nil { + log.Errorf("failed to update file when finish session, err: %+v", err) + return err + } + _ = res.Body.Close() + return nil +} + +func (d *Dropbox) startUploadSession(ctx context.Context) (string, error) { + url := d.contentBase + "/2/files/upload_session/start" + req, err := http.NewRequest(http.MethodPost, url, nil) + if err != nil { + return "", err + } + req = req.WithContext(ctx) + req.Header.Set("Content-Type", "application/octet-stream") + req.Header.Set("Authorization", "Bearer "+d.AccessToken) + req.Header.Set("Dropbox-API-Arg", "{\"close\":false}") + + res, err := base.HttpClient.Do(req) + if err != nil { + log.Errorf("failed to update file when start session, err: %+v", err) + return "", err + } + + body, err := io.ReadAll(res.Body) + sessionId := utils.Json.Get(body, "session_id").ToString() + + _ = res.Body.Close() + return sessionId, nil +} diff --git a/drivers/fastwebdav/driver.go b/drivers/fastwebdav/driver.go new file mode 100644 index 0000000000000000000000000000000000000000..f92b5a6e558375d5656d75a817761f9145c67548 --- /dev/null +++ b/drivers/fastwebdav/driver.go @@ -0,0 +1,220 @@ +package fastwebdav + +import ( + "context" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "net/http" + "strconv" + "strings" + + "github.com/alist-org/alist/v3/drivers/base" + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/go-resty/resty/v2" +) + +type FastWebdav struct { + model.Storage + Addition +} + +func (d *FastWebdav) Config() driver.Config { + return config +} + +func (d *FastWebdav) GetAddition() driver.Additional { + return &d.Addition +} + +func (d *FastWebdav) Init(ctx context.Context) error { + d.Address = strings.TrimSuffix(d.Address, "/") + return nil +} + +func (d *FastWebdav) Drop(ctx context.Context) error { + return nil +} + +func (d *FastWebdav) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) { + files, err := d.getFiles(dir.GetPath(), dir.GetID()) + if err != nil { + return nil, err + } + return utils.SliceConvert(files, func(src File) (model.Obj, error) { + return fileToObj(src), nil + }) +} + +func (d *FastWebdav) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) { + var dUrl string + b, _ := base64.StdEncoding.DecodeString(file.GetID()) + + var f File + _ = json.Unmarshal(b, &f) + url := f.Provider + "/url" + + if len(f.DownloadUrl) > 4 { + dUrl = 
f.DownloadUrl
+    } else {
+        err := d.request(http.MethodPost, url, func(req *resty.Request) {
+            req.SetBody(f)
+        }, &dUrl)
+        if err != nil {
+            return nil, err
+        }
+    }
+
+    // if strings.HasPrefix(dUrl, "/api") {
+    //     dUrl = d.Address + dUrl
+    // }
+
+    link := &model.Link{
+        URL: dUrl,
+    }
+
+    if len(f.PlayHeaders) > 4 {
+        var headers map[string]string
+        err := json.Unmarshal([]byte(f.PlayHeaders), &headers)
+        if err != nil {
+            fmt.Println("failed to parse the custom headers:", err)
+            return link, err
+        }
+
+        // convert the map into an http.Header
+        header := make(http.Header)
+        for key, value := range headers {
+            header.Add(key, value)
+        }
+
+        link.Header = header
+    }
+    return link, nil
+}
+
+func (d *FastWebdav) MakeDir(ctx context.Context, parentDir model.Obj, dirName string) error {
+    return d.request(http.MethodPut, "/directory", func(req *resty.Request) {
+        req.SetBody(base.Json{
+            "path": parentDir.GetPath() + "/" + dirName,
+        })
+    }, nil)
+}
+
+func (d *FastWebdav) Move(ctx context.Context, srcObj, dstDir model.Obj) error {
+    body := base.Json{
+        "action":  "move",
+        "src_dir": srcObj.GetPath(),
+        "dst":     dstDir.GetPath(),
+        "src":     convertSrc(srcObj),
+    }
+    return d.request(http.MethodPatch, "/object", func(req *resty.Request) {
+        req.SetBody(body)
+    }, nil)
+}
+
+func (d *FastWebdav) Rename(ctx context.Context, srcObj model.Obj, newName string) error {
+    body := base.Json{
+        "action":   "rename",
+        "new_name": newName,
+        "src":      convertSrc(srcObj),
+    }
+    return d.request(http.MethodPatch, "/object/rename", func(req *resty.Request) {
+        req.SetBody(body)
+    }, nil)
+}
+
+func (d *FastWebdav) Copy(ctx context.Context, srcObj, dstDir model.Obj) error {
+    body := base.Json{
+        "src_dir": srcObj.GetPath(),
+        "dst":     dstDir.GetPath(),
+        "src":     convertSrc(srcObj),
+    }
+    return d.request(http.MethodPost, "/object/copy", func(req *resty.Request) {
+        req.SetBody(body)
+    }, nil)
+}
+
+func (d *FastWebdav) Remove(ctx context.Context, obj model.Obj) error {
+    body := convertSrc(obj)
+    err := d.request(http.MethodDelete, "/object", func(req *resty.Request) {
+        req.SetBody(body)
+    }, nil)
+    return err
+}
+
+func (d *FastWebdav) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error {
+    if io.ReadCloser(stream) == http.NoBody {
+        return d.create(ctx, dstDir, stream)
+    }
+    var r DirectoryResp
+    err := d.request(http.MethodGet, "/directory"+dstDir.GetPath(), nil, &r)
+    if err != nil {
+        return err
+    }
+    uploadBody := base.Json{
+        "path":          dstDir.GetPath(),
+        "size":          stream.GetSize(),
+        "name":          stream.GetName(),
+        "policy_id":     r.Policy.Id,
+        "last_modified": stream.ModTime().Unix(),
+    }
+    var u UploadInfo
+    err = d.request(http.MethodPut, "/file/upload", func(req *resty.Request) {
+        req.SetBody(uploadBody)
+    }, &u)
+    if err != nil {
+        return err
+    }
+    var chunkSize = u.ChunkSize
+    var buf []byte
+    var chunk int
+    for {
+        var n int
+        buf = make([]byte, chunkSize)
+        n, err = io.ReadAtLeast(stream, buf, chunkSize)
+        if err != nil && err != io.ErrUnexpectedEOF {
+            if err == io.EOF {
+                return nil
+            }
+            return err
+        }
+        if n == 0 {
+            break
+        }
+        buf = buf[:n]
+        err = d.request(http.MethodPost, "/file/upload/"+u.SessionID+"/"+strconv.Itoa(chunk), func(req *resty.Request) {
+            req.SetHeader("Content-Type", "application/octet-stream")
+            req.SetHeader("Content-Length", strconv.Itoa(n))
+            req.SetBody(buf)
+        }, nil)
+        if err != nil {
+            break
+        }
+        chunk++
+    }
+    return err
+}
+
+func (d *FastWebdav) create(ctx context.Context, dir model.Obj, file model.Obj) error {
+    body := base.Json{"path": dir.GetPath() + "/" + file.GetName()}
+    if file.IsDir() {
+        err := d.request(http.MethodPut, "directory", func(req *resty.Request) {
+            req.SetBody(body)
+        }, nil)
+        return err
+    }
+    return d.request(http.MethodPost, "/file/create", func(req *resty.Request) {
+        req.SetBody(body)
+    }, nil)
+}
+
+//func (d *FastWebdav) Other(ctx context.Context, args model.OtherArgs) (interface{}, error) {
+//    return nil, errs.NotSupport
+//}
+
+var _ driver.Driver = (*FastWebdav)(nil)
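FastWebdav never stores bare server-side IDs: as the Link and getFiles code shows, the whole File record round-trips through model.Object.ID as base64-encoded JSON. A self-contained sketch of that encoding, with a trimmed-down File type and made-up field values:

package main

import (
    "encoding/base64"
    "encoding/json"
    "fmt"
)

// a trimmed-down copy of the File type from the diff
type File struct {
    Id       string `json:"id"`
    Provider string `json:"provider"`
    Name     string `json:"name"`
}

func main() {
    f := File{Id: "42", Provider: "some-provider", Name: "movie.mkv"}
    b, _ := json.Marshal(f)
    id := base64.StdEncoding.EncodeToString(b) // what gets stored as model.Object.ID

    // later, Link() and getFiles() decode the ID back into a File
    raw, _ := base64.StdEncoding.DecodeString(id)
    var back File
    _ = json.Unmarshal(raw, &back)
    fmt.Println(back.Provider, back.Name) // some-provider movie.mkv
}

This keeps the driver stateless: every operation can recover the provider and upstream ID from the object ID alone.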
"/" + file.GetName()} + if file.IsDir() { + err := d.request(http.MethodPut, "directory", func(req *resty.Request) { + req.SetBody(body) + }, nil) + return err + } + return d.request(http.MethodPost, "/file/create", func(req *resty.Request) { + req.SetBody(body) + }, nil) +} + +//func (d *FastWebdav) Other(ctx context.Context, args model.OtherArgs) (interface{}, error) { +// return nil, errs.NotSupport +//} + +var _ driver.Driver = (*FastWebdav)(nil) diff --git a/drivers/fastwebdav/meta.go b/drivers/fastwebdav/meta.go new file mode 100644 index 0000000000000000000000000000000000000000..7e79dc8dc04a648ebb77ec490eba85256103c988 --- /dev/null +++ b/drivers/fastwebdav/meta.go @@ -0,0 +1,25 @@ +package fastwebdav + +import ( + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/op" +) + +type Addition struct { + driver.RootPath + Address string `json:"address" required:"true"` + APIKey string `json:"password" required:"true"` +} + +var config = driver.Config{ + Name: "FastWebdav", + DefaultRoot: "/", + NoUpload: true, + Alert: "warning|只支持读文件不能进行其它操作,如:复制,移动,上传等", +} + +func init() { + op.RegisterDriver(func() driver.Driver { + return &FastWebdav{} + }) +} diff --git a/drivers/fastwebdav/types.go b/drivers/fastwebdav/types.go new file mode 100644 index 0000000000000000000000000000000000000000..c98de8658bc72fa41adbf4345de0c801b6994938 --- /dev/null +++ b/drivers/fastwebdav/types.go @@ -0,0 +1,168 @@ +package fastwebdav + +import ( + "encoding/base64" + "encoding/json" + "strconv" + "time" + + "github.com/alist-org/alist/v3/internal/errs" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/pkg/utils" + hash_extend "github.com/alist-org/alist/v3/pkg/utils/hash" +) + +type Resp struct { + Code int `json:"code"` + Msg string `json:"msg"` + Data interface{} `json:"data"` +} + +type Policy struct { + Id string `json:"id"` + Name string `json:"name"` + Type string `json:"type"` + MaxSize int `json:"max_size"` + FileType []string `json:"file_type"` +} + +type UploadInfo struct { + SessionID string `json:"sessionID"` + ChunkSize int `json:"chunkSize"` + Expires int `json:"expires"` +} + +type DirectoryResp struct { + Parent string `json:"parent"` + Objects []Object `json:"objects"` + Policy Policy `json:"policy"` +} + +type Object struct { + Id string `json:"id"` + Name string `json:"name"` + Path string `json:"path"` + Pic string `json:"pic"` + Size int `json:"size"` + Type string `json:"type"` + Date time.Time `json:"date"` + CreateDate time.Time `json:"create_date"` + SourceEnabled bool `json:"source_enabled"` +} + +type DirectoryProp struct { + Size int `json:"size"` +} + +func objectToObj(f Object, t model.Thumbnail) *model.ObjThumb { + return &model.ObjThumb{ + Object: model.Object{ + ID: f.Id, + Name: f.Name, + Size: int64(f.Size), + Modified: f.Date, + IsFolder: f.Type == "dir", + }, + Thumbnail: t, + } +} + +type Config struct { + LoginCaptcha bool `json:"loginCaptcha"` + CaptchaType string `json:"captcha_type"` +} + +type File struct { + Id string `json:"id"` + Kind int `json:"kind"` + Provider string `json:"provider"` + Name string `json:"name"` + CreateTime string `json:"create_time"` + Sha1 string `json:"sha1"` + Size string `json:"size"` + ParentId string `json:"parent_id"` + Oriname string `json:"oriname"` + DownloadUrl string `json:"download_url"` + PlayHeaders string `json:"play_headers"` + Password string `json:"password"` +} + +func fileToObj(f File) *model.Object { + size, _ := strconv.ParseInt(f.Size, 10, 64) + 
create_time, _ := time.Parse("2006-01-02 15:04:05", f.CreateTime) + b, _ := json.Marshal(f) + file_id := base64.StdEncoding.EncodeToString(b) + file := &model.Object{ + ID: file_id, + Name: f.Name, + Size: size, + Ctime: create_time, + Modified: create_time, + IsFolder: f.Kind == 0, + HashInfo: utils.NewHashInfo(hash_extend.GCID, f.Sha1), + } + + // if len(f.DownloadUrl) > 4 { + // file.Url = model.Url{Url: f.DownloadUrl} + // } + + return file +} + +// Node is a node in the folder tree +type Node struct { + Url string + Name string + Level int + Modified int64 + Size int64 + Children []*Node +} + +func (node *Node) getByPath(paths []string) *Node { + if len(paths) == 0 || node == nil { + return nil + } + if node.Name != paths[0] { + return nil + } + if len(paths) == 1 { + return node + } + for _, child := range node.Children { + tmp := child.getByPath(paths[1:]) + if tmp != nil { + return tmp + } + } + return nil +} + +func (node *Node) isFile() bool { + return node.Url != "" +} + +func (node *Node) calSize() int64 { + if node.isFile() { + return node.Size + } + var size int64 = 0 + for _, child := range node.Children { + size += child.calSize() + } + node.Size = size + return size +} + +func nodeToObj(node *Node, path string) (model.Obj, error) { + if node == nil { + return nil, errs.ObjectNotFound + } + return &model.Object{ + Name: node.Name, + Size: node.Size, + Modified: time.Unix(node.Modified, 0), + IsFolder: !node.isFile(), + Path: path, + }, nil +} diff --git a/drivers/fastwebdav/util.go b/drivers/fastwebdav/util.go new file mode 100644 index 0000000000000000000000000000000000000000..c47d48b396bca3f100a56b1712b582493eb99d61 --- /dev/null +++ b/drivers/fastwebdav/util.go @@ -0,0 +1,102 @@ +package fastwebdav + +import ( + "encoding/base64" + "encoding/json" + "errors" + "net/http" + "strings" + + "github.com/alist-org/alist/v3/drivers/base" + "github.com/alist-org/alist/v3/internal/model" + "github.com/go-resty/resty/v2" + log "github.com/sirupsen/logrus" +) + +// do others that not defined in Driver interface + +const loginPath = "/user/session" + +func (d *FastWebdav) request(method string, path string, callback base.ReqCallback, resp interface{}) error { + d.Address = strings.TrimSuffix(d.Address, "/") + u := d.Address + "/" + path + + req := base.RestyClient.R() + req.SetHeaders(map[string]string{ + "X-Space-App-Key": d.APIKey, + "Accept": "application/json, text/plain, */*", + "Content-Type": "application/json", + }) + + if callback != nil { + callback(req) + } + + if resp != nil { + req.SetResult(resp) + } + + res, err := req.Execute(method, u) + if err != nil { + return err + } + if !res.IsSuccess() { + return errors.New(res.String()) + } + + return nil +} + +func convertSrc(obj model.Obj) map[string]interface{} { + m := make(map[string]interface{}) + var dirs []string + var items []string + if obj.IsDir() { + dirs = append(dirs, obj.GetID()) + } else { + items = append(items, obj.GetID()) + } + m["dirs"] = dirs + m["items"] = items + return m +} + +func (d *FastWebdav) getFiles(path string, id string) ([]File, error) { + url := "" + body := base.Json{} + httpMethod := http.MethodGet + + if path != "/" { + provider := getProvider(path) + url = provider + "/list" + log.Debug(url) + httpMethod = http.MethodPost + b, _ := base64.StdEncoding.DecodeString(id) + var f File + _ = json.Unmarshal(b, &f) + body = base.Json{ + "path_str": path, + "parent_file_id": f.Id, + } + } + + res := make([]File, 0) + var resp []File + err := d.request(httpMethod, url, func(req *resty.Request) { 
+        req.SetBody(body)
+    }, &resp)
+    if err != nil {
+        return nil, err
+    }
+    res = append(res, resp...)
+    return res, nil
+}
+
+func getProvider(s string) string {
+    if strings.Count(s, "/") >= 2 {
+        start := strings.Index(s, "/")
+        end := strings.Index(s[start+1:], "/") + start + 1
+        return s[start+1 : end]
+    }
+    return s[1:]
+}
diff --git a/drivers/febbox/driver.go b/drivers/febbox/driver.go
new file mode 100644
index 0000000000000000000000000000000000000000..55c3aa211fe04d725747de6630a09fd9b5013a02
--- /dev/null
+++ b/drivers/febbox/driver.go
@@ -0,0 +1,132 @@
+package febbox
+
+import (
+    "context"
+
+    "github.com/alist-org/alist/v3/internal/driver"
+    "github.com/alist-org/alist/v3/internal/errs"
+    "github.com/alist-org/alist/v3/internal/model"
+    "github.com/alist-org/alist/v3/internal/op"
+    "github.com/alist-org/alist/v3/pkg/utils"
+    "golang.org/x/oauth2"
+    "golang.org/x/oauth2/clientcredentials"
+)
+
+type FebBox struct {
+    model.Storage
+    Addition
+    accessToken string
+    oauth2Token oauth2.TokenSource
+}
+
+func (d *FebBox) Config() driver.Config {
+    return config
+}
+
+func (d *FebBox) GetAddition() driver.Additional {
+    return &d.Addition
+}
+
+func (d *FebBox) Init(ctx context.Context) error {
+    // initialize the oauth2 client-credentials config
+    oauth2Config := &clientcredentials.Config{
+        ClientID:     d.ClientID,
+        ClientSecret: d.ClientSecret,
+        AuthStyle:    oauth2.AuthStyleInParams,
+        TokenURL:     "https://api.febbox.com/oauth/token",
+    }
+
+    d.initializeOAuth2Token(ctx, oauth2Config, d.Addition.RefreshToken)
+
+    token, err := d.oauth2Token.Token()
+    if err != nil {
+        return err
+    }
+    d.accessToken = token.AccessToken
+    d.Addition.RefreshToken = token.RefreshToken
+    op.MustSaveDriverStorage(d)
+
+    return nil
+}
+
+func (d *FebBox) Drop(ctx context.Context) error {
+    return nil
+}
+
+func (d *FebBox) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) {
+    files, err := d.getFilesList(dir.GetID())
+    if err != nil {
+        return nil, err
+    }
+    return utils.SliceConvert(files, func(src File) (model.Obj, error) {
+        return fileToObj(src), nil
+    })
+}
+
+func (d *FebBox) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) {
+    var ip string
+    if d.Addition.UserIP != "" {
+        ip = d.Addition.UserIP
+    } else {
+        ip = args.IP
+    }
+
+    url, err := d.getDownloadLink(file.GetID(), ip)
+    if err != nil {
+        return nil, err
+    }
+    return &model.Link{
+        URL: url,
+    }, nil
+}
+
+func (d *FebBox) MakeDir(ctx context.Context, parentDir model.Obj, dirName string) (model.Obj, error) {
+    err := d.makeDir(parentDir.GetID(), dirName)
+    if err != nil {
+        return nil, err
+    }
+
+    return nil, nil
+}
+
+func (d *FebBox) Move(ctx context.Context, srcObj, dstDir model.Obj) (model.Obj, error) {
+    err := d.move(srcObj.GetID(), dstDir.GetID())
+    if err != nil {
+        return nil, err
+    }
+
+    return nil, nil
+}
+
+func (d *FebBox) Rename(ctx context.Context, srcObj model.Obj, newName string) (model.Obj, error) {
+    err := d.rename(srcObj.GetID(), newName)
+    if err != nil {
+        return nil, err
+    }
+
+    return nil, nil
+}
+
+func (d *FebBox) Copy(ctx context.Context, srcObj, dstDir model.Obj) (model.Obj, error) {
+    err := d.copy(srcObj.GetID(), dstDir.GetID())
+    if err != nil {
+        return nil, err
+    }
+
+    return nil, nil
+}
+
+func (d *FebBox) Remove(ctx context.Context, obj model.Obj) error {
+    err := d.remove(obj.GetID())
+    if err != nil {
+        return err
+    }
+
+    return nil
+}
+
+func (d *FebBox) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) (model.Obj, error) {
return nil, errs.NotImplement +} + +var _ driver.Driver = (*FebBox)(nil) diff --git a/drivers/febbox/meta.go b/drivers/febbox/meta.go new file mode 100644 index 0000000000000000000000000000000000000000..1daeeea8e52b4221d3047d2cff91f80fbce3690f --- /dev/null +++ b/drivers/febbox/meta.go @@ -0,0 +1,36 @@ +package febbox + +import ( + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/op" +) + +type Addition struct { + driver.RootID + ClientID string `json:"client_id" required:"true" default:""` + ClientSecret string `json:"client_secret" required:"true" default:""` + RefreshToken string + SortRule string `json:"sort_rule" required:"true" type:"select" options:"size_asc,size_desc,name_asc,name_desc,update_asc,update_desc,ext_asc,ext_desc" default:"name_asc"` + PageSize int64 `json:"page_size" required:"true" type:"number" default:"100" help:"list api per page size of FebBox driver"` + UserIP string `json:"user_ip" default:"" help:"user ip address for download link which can speed up the download"` +} + +var config = driver.Config{ + Name: "FebBox", + LocalSort: false, + OnlyLocal: false, + OnlyProxy: false, + NoCache: false, + NoUpload: true, + NeedMs: false, + DefaultRoot: "0", + CheckStatus: false, + Alert: "", + NoOverwriteUpload: false, +} + +func init() { + op.RegisterDriver(func() driver.Driver { + return &FebBox{} + }) +} diff --git a/drivers/febbox/oauth2.go b/drivers/febbox/oauth2.go new file mode 100644 index 0000000000000000000000000000000000000000..6345d1a711ee66b727f365155a3c3b1afa85c9e5 --- /dev/null +++ b/drivers/febbox/oauth2.go @@ -0,0 +1,88 @@ +package febbox + +import ( + "context" + "encoding/json" + "errors" + "net/http" + "net/url" + "strings" + "time" + + "golang.org/x/oauth2" + "golang.org/x/oauth2/clientcredentials" +) + +type customTokenSource struct { + config *clientcredentials.Config + ctx context.Context + refreshToken string +} + +func (c *customTokenSource) Token() (*oauth2.Token, error) { + v := url.Values{} + if c.refreshToken != "" { + v.Set("grant_type", "refresh_token") + v.Set("refresh_token", c.refreshToken) + } else { + v.Set("grant_type", "client_credentials") + } + + v.Set("client_id", c.config.ClientID) + v.Set("client_secret", c.config.ClientSecret) + + req, err := http.NewRequest("POST", c.config.TokenURL, strings.NewReader(v.Encode())) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + resp, err := http.DefaultClient.Do(req.WithContext(c.ctx)) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, errors.New("oauth2: cannot fetch token") + } + + var tokenResp struct { + Code int `json:"code"` + Msg string `json:"msg"` + Data struct { + AccessToken string `json:"access_token"` + ExpiresIn int64 `json:"expires_in"` + TokenType string `json:"token_type"` + Scope string `json:"scope"` + RefreshToken string `json:"refresh_token"` + } `json:"data"` + } + + if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil { + return nil, err + } + + if tokenResp.Code != 1 { + return nil, errors.New("oauth2: server response error") + } + + c.refreshToken = tokenResp.Data.RefreshToken + + token := &oauth2.Token{ + AccessToken: tokenResp.Data.AccessToken, + TokenType: tokenResp.Data.TokenType, + RefreshToken: tokenResp.Data.RefreshToken, + Expiry: time.Now().Add(time.Duration(tokenResp.Data.ExpiresIn) * time.Second), + } + + return token, nil +} + +func (d *FebBox) 
initializeOAuth2Token(ctx context.Context, oauth2Config *clientcredentials.Config, refreshToken string) { + d.oauth2Token = oauth2.ReuseTokenSource(nil, &customTokenSource{ + config: oauth2Config, + ctx: ctx, + refreshToken: refreshToken, + }) +} diff --git a/drivers/febbox/types.go b/drivers/febbox/types.go new file mode 100644 index 0000000000000000000000000000000000000000..2ac6d6b76cc8784a1d34e42b7e5098883e5da02c --- /dev/null +++ b/drivers/febbox/types.go @@ -0,0 +1,123 @@ +package febbox + +import ( + "fmt" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/pkg/utils" + hash_extend "github.com/alist-org/alist/v3/pkg/utils/hash" + "strconv" + "time" +) + +type ErrResp struct { + ErrorCode int64 `json:"code"` + ErrorMsg string `json:"msg"` + ServerRunTime float64 `json:"server_runtime"` + ServerName string `json:"server_name"` +} + +func (e *ErrResp) IsError() bool { + return e.ErrorCode != 0 || e.ErrorMsg != "" || e.ServerRunTime != 0 || e.ServerName != "" +} + +func (e *ErrResp) Error() string { + return fmt.Sprintf("ErrorCode: %d ,Error: %s ,ServerRunTime: %f ,ServerName: %s", e.ErrorCode, e.ErrorMsg, e.ServerRunTime, e.ServerName) +} + +type FileListResp struct { + Code int `json:"code"` + Msg string `json:"msg"` + Data struct { + FileList []File `json:"file_list"` + ShowType string `json:"show_type"` + } `json:"data"` +} + +type Rules struct { + AllowCopy int64 `json:"allow_copy"` + AllowDelete int64 `json:"allow_delete"` + AllowDownload int64 `json:"allow_download"` + AllowComment int64 `json:"allow_comment"` + HideLocation int64 `json:"hide_location"` +} + +type File struct { + Fid int64 `json:"fid"` + UID int64 `json:"uid"` + FileSize int64 `json:"file_size"` + Path string `json:"path"` + FileName string `json:"file_name"` + Ext string `json:"ext"` + AddTime int64 `json:"add_time"` + FileCreateTime int64 `json:"file_create_time"` + FileUpdateTime int64 `json:"file_update_time"` + ParentID int64 `json:"parent_id"` + UpdateTime int64 `json:"update_time"` + LastOpenTime int64 `json:"last_open_time"` + IsDir int64 `json:"is_dir"` + Epub int64 `json:"epub"` + IsMusicList int64 `json:"is_music_list"` + OssFid int64 `json:"oss_fid"` + Faststart int64 `json:"faststart"` + HasVideoQuality int64 `json:"has_video_quality"` + TotalDownload int64 `json:"total_download"` + Status int64 `json:"status"` + Remark string `json:"remark"` + OldHash string `json:"old_hash"` + Hash string `json:"hash"` + HashType string `json:"hash_type"` + FromUID int64 `json:"from_uid"` + FidOrg int64 `json:"fid_org"` + ShareID int64 `json:"share_id"` + InvitePermission int64 `json:"invite_permission"` + ThumbSmall string `json:"thumb_small"` + ThumbSmallWidth int64 `json:"thumb_small_width"` + ThumbSmallHeight int64 `json:"thumb_small_height"` + Thumb string `json:"thumb"` + ThumbWidth int64 `json:"thumb_width"` + ThumbHeight int64 `json:"thumb_height"` + ThumbBig string `json:"thumb_big"` + ThumbBigWidth int64 `json:"thumb_big_width"` + ThumbBigHeight int64 `json:"thumb_big_height"` + IsCustomThumb int64 `json:"is_custom_thumb"` + Photos int64 `json:"photos"` + IsAlbum int64 `json:"is_album"` + ReadOnly int64 `json:"read_only"` + Rules Rules `json:"rules"` + IsShared int64 `json:"is_shared"` +} + +func fileToObj(f File) *model.ObjThumb { + return &model.ObjThumb{ + Object: model.Object{ + ID: strconv.FormatInt(f.Fid, 10), + Name: f.FileName, + Size: f.FileSize, + Ctime: time.Unix(f.FileCreateTime, 0), + Modified: time.Unix(f.FileUpdateTime, 0), + IsFolder: f.IsDir == 1, + HashInfo: 
utils.NewHashInfo(hash_extend.GCID, f.Hash),
+		},
+		Thumbnail: model.Thumbnail{
+			Thumbnail: f.Thumb,
+		},
+	}
+}
+
+type FileDownloadResp struct {
+	Code int    `json:"code"`
+	Msg  string `json:"msg"`
+	Data []struct {
+		Error       int    `json:"error"`
+		DownloadURL string `json:"download_url"`
+		Hash        string `json:"hash"`
+		HashType    string `json:"hash_type"`
+		Fid         int    `json:"fid"`
+		FileName    string `json:"file_name"`
+		ParentID    int    `json:"parent_id"`
+		FileSize    int    `json:"file_size"`
+		Ext         string `json:"ext"`
+		Thumb       string `json:"thumb"`
+		VipLink     int    `json:"vip_link"`
+	} `json:"data"`
+}
diff --git a/drivers/febbox/util.go b/drivers/febbox/util.go
new file mode 100644
index 0000000000000000000000000000000000000000..ac072edbde8956a6a8ff6549c3fae8e1405f08ec
--- /dev/null
+++ b/drivers/febbox/util.go
@@ -0,0 +1,224 @@
+package febbox
+
+import (
+	"encoding/json"
+	"errors"
+	"github.com/alist-org/alist/v3/drivers/base"
+	"github.com/alist-org/alist/v3/internal/op"
+	"github.com/go-resty/resty/v2"
+	"net/http"
+	"strconv"
+)
+
+func (d *FebBox) refreshTokenByOAuth2() error {
+	token, err := d.oauth2Token.Token()
+	if err != nil {
+		return err
+	}
+	d.Status = "work"
+	d.accessToken = token.AccessToken
+	d.Addition.RefreshToken = token.RefreshToken
+	op.MustSaveDriverStorage(d)
+	return nil
+}
+
+func (d *FebBox) request(url string, method string, callback base.ReqCallback, resp interface{}) ([]byte, error) {
+	req := base.RestyClient.R()
+	// obtain an access token from the oauth2 token source
+	token, err := d.oauth2Token.Token()
+	if err != nil {
+		return nil, err
+	}
+	req.SetAuthScheme(token.TokenType).SetAuthToken(token.AccessToken)
+
+	if callback != nil {
+		callback(req)
+	}
+	if resp != nil {
+		req.SetResult(resp)
+	}
+	var e ErrResp
+	req.SetError(&e)
+	res, err := req.Execute(method, url)
+	if err != nil {
+		return nil, err
+	}
+
+	switch e.ErrorCode {
+	case 0, 1:
+		return res.Body(), nil
+	case -10001:
+		if e.ServerName != "" {
+			// the access token has expired; refresh it and retry the request
+			if err = d.refreshTokenByOAuth2(); err != nil {
+				return nil, err
+			}
+			return d.request(url, method, callback, resp)
+		} else {
+			return nil, errors.New(e.Error())
+		}
+	default:
+		return nil, errors.New(e.Error())
+	}
+}
+
+func (d *FebBox) getFilesList(id string) ([]File, error) {
+	if d.PageSize <= 0 {
+		d.PageSize = 100
+	}
+	res, err := d.listWithLimit(id, d.PageSize)
+	if err != nil {
+		return nil, err
+	}
+	return *res, nil
+}
+
+func (d *FebBox) listWithLimit(dirID string, pageLimit int64) (*[]File, error) {
+	var files []File
+	page := int64(1)
+	for {
+		result, err := d.getFiles(dirID, page, pageLimit)
+		if err != nil {
+			return nil, err
+		}
+		files = append(files, *result...)
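+		// a short page marks the last page: if the server returned fewer
+		// entries than requested, stop; otherwise advance to the next page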
+ if int64(len(*result)) < pageLimit { + break + } else { + page++ + } + } + return &files, nil +} + +func (d *FebBox) getFiles(dirID string, page, pageLimit int64) (*[]File, error) { + var fileList FileListResp + queryParams := map[string]string{ + "module": "file_list", + "parent_id": dirID, + "page": strconv.FormatInt(page, 10), + "pagelimit": strconv.FormatInt(pageLimit, 10), + "order": d.Addition.SortRule, + } + + res, err := d.request("https://api.febbox.com/oauth", http.MethodPost, func(req *resty.Request) { + req.SetMultipartFormData(queryParams) + }, &fileList) + if err != nil { + return nil, err + } + + if err = json.Unmarshal(res, &fileList); err != nil { + return nil, err + } + + return &fileList.Data.FileList, nil +} + +func (d *FebBox) getDownloadLink(id string, ip string) (string, error) { + var fileDownloadResp FileDownloadResp + queryParams := map[string]string{ + "module": "file_get_download_url", + "fids[]": id, + "ip": ip, + } + + res, err := d.request("https://api.febbox.com/oauth", http.MethodPost, func(req *resty.Request) { + req.SetMultipartFormData(queryParams) + }, &fileDownloadResp) + if err != nil { + return "", err + } + + if err = json.Unmarshal(res, &fileDownloadResp); err != nil { + return "", err + } + + return fileDownloadResp.Data[0].DownloadURL, nil +} + +func (d *FebBox) makeDir(id string, name string) error { + queryParams := map[string]string{ + "module": "create_dir", + "parent_id": id, + "name": name, + } + + _, err := d.request("https://api.febbox.com/oauth", http.MethodPost, func(req *resty.Request) { + req.SetMultipartFormData(queryParams) + }, nil) + if err != nil { + return err + } + + return nil +} + +func (d *FebBox) move(id string, id2 string) error { + queryParams := map[string]string{ + "module": "file_move", + "fids[]": id, + "to": id2, + } + + _, err := d.request("https://api.febbox.com/oauth", http.MethodPost, func(req *resty.Request) { + req.SetMultipartFormData(queryParams) + }, nil) + if err != nil { + return err + } + + return nil +} + +func (d *FebBox) rename(id string, name string) error { + queryParams := map[string]string{ + "module": "file_rename", + "fid": id, + "name": name, + } + + _, err := d.request("https://api.febbox.com/oauth", http.MethodPost, func(req *resty.Request) { + req.SetMultipartFormData(queryParams) + }, nil) + if err != nil { + return err + } + + return nil +} + +func (d *FebBox) copy(id string, id2 string) error { + queryParams := map[string]string{ + "module": "file_copy", + "fids[]": id, + "to": id2, + } + + _, err := d.request("https://api.febbox.com/oauth", http.MethodPost, func(req *resty.Request) { + req.SetMultipartFormData(queryParams) + }, nil) + if err != nil { + return err + } + + return nil +} + +func (d *FebBox) remove(id string) error { + queryParams := map[string]string{ + "module": "file_delete", + "fids[]": id, + } + + _, err := d.request("https://api.febbox.com/oauth", http.MethodPost, func(req *resty.Request) { + req.SetMultipartFormData(queryParams) + }, nil) + if err != nil { + return err + } + + return nil +} diff --git a/drivers/ftp/driver.go b/drivers/ftp/driver.go new file mode 100644 index 0000000000000000000000000000000000000000..05b9e49a91d428a3a8f5a290606622fe39bb8f21 --- /dev/null +++ b/drivers/ftp/driver.go @@ -0,0 +1,126 @@ +package ftp + +import ( + "context" + stdpath "path" + + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/errs" + "github.com/alist-org/alist/v3/internal/model" + "github.com/jlaffaye/ftp" +) + +type FTP struct { + 
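+	// model.Storage and Addition are embedded so one value carries both the
+	// generic storage record and the FTP-specific options; conn caches the
+	// control connection, which login() re-validates before each operation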
model.Storage + Addition + conn *ftp.ServerConn +} + +func (d *FTP) Config() driver.Config { + return config +} + +func (d *FTP) GetAddition() driver.Additional { + return &d.Addition +} + +func (d *FTP) Init(ctx context.Context) error { + return d.login() +} + +func (d *FTP) Drop(ctx context.Context) error { + if d.conn != nil { + _ = d.conn.Logout() + } + return nil +} + +func (d *FTP) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) { + if err := d.login(); err != nil { + return nil, err + } + entries, err := d.conn.List(encode(dir.GetPath(), d.Encoding)) + if err != nil { + return nil, err + } + res := make([]model.Obj, 0) + for _, entry := range entries { + if entry.Name == "." || entry.Name == ".." { + continue + } + f := model.Object{ + Name: decode(entry.Name, d.Encoding), + Size: int64(entry.Size), + Modified: entry.Time, + IsFolder: entry.Type == ftp.EntryTypeFolder, + } + res = append(res, &f) + } + return res, nil +} + +func (d *FTP) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) { + if err := d.login(); err != nil { + return nil, err + } + + r := NewFileReader(d.conn, encode(file.GetPath(), d.Encoding), file.GetSize()) + link := &model.Link{ + MFile: r, + } + return link, nil +} + +func (d *FTP) MakeDir(ctx context.Context, parentDir model.Obj, dirName string) error { + if err := d.login(); err != nil { + return err + } + return d.conn.MakeDir(encode(stdpath.Join(parentDir.GetPath(), dirName), d.Encoding)) +} + +func (d *FTP) Move(ctx context.Context, srcObj, dstDir model.Obj) error { + if err := d.login(); err != nil { + return err + } + return d.conn.Rename( + encode(srcObj.GetPath(), d.Encoding), + encode(stdpath.Join(dstDir.GetPath(), srcObj.GetName()), d.Encoding), + ) +} + +func (d *FTP) Rename(ctx context.Context, srcObj model.Obj, newName string) error { + if err := d.login(); err != nil { + return err + } + return d.conn.Rename( + encode(srcObj.GetPath(), d.Encoding), + encode(stdpath.Join(stdpath.Dir(srcObj.GetPath()), newName), d.Encoding), + ) +} + +func (d *FTP) Copy(ctx context.Context, srcObj, dstDir model.Obj) error { + return errs.NotSupport +} + +func (d *FTP) Remove(ctx context.Context, obj model.Obj) error { + if err := d.login(); err != nil { + return err + } + path := encode(obj.GetPath(), d.Encoding) + if obj.IsDir() { + return d.conn.RemoveDirRecur(path) + } else { + return d.conn.Delete(path) + } +} + +func (d *FTP) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error { + if err := d.login(); err != nil { + return err + } + // TODO: support cancel + path := stdpath.Join(dstDir.GetPath(), stream.GetName()) + return d.conn.Stor(encode(path, d.Encoding), stream) +} + +var _ driver.Driver = (*FTP)(nil) diff --git a/drivers/ftp/meta.go b/drivers/ftp/meta.go new file mode 100644 index 0000000000000000000000000000000000000000..5652c12e18450a9eba15199f4b165a9bd5eee60b --- /dev/null +++ b/drivers/ftp/meta.go @@ -0,0 +1,44 @@ +package ftp + +import ( + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/op" + "github.com/axgle/mahonia" +) + +func encode(str string, encoding string) string { + if encoding == "" { + return str + } + encoder := mahonia.NewEncoder(encoding) + return encoder.ConvertString(str) +} + +func decode(str string, encoding string) string { + if encoding == "" { + return str + } + decoder := mahonia.NewDecoder(encoding) + return decoder.ConvertString(str) +} + +type Addition struct { + 
Address string `json:"address" required:"true"` + Encoding string `json:"encoding" required:"true"` + Username string `json:"username" required:"true"` + Password string `json:"password" required:"true"` + driver.RootPath +} + +var config = driver.Config{ + Name: "FTP", + LocalSort: true, + OnlyLocal: true, + DefaultRoot: "/", +} + +func init() { + op.RegisterDriver(func() driver.Driver { + return &FTP{} + }) +} diff --git a/drivers/ftp/types.go b/drivers/ftp/types.go new file mode 100644 index 0000000000000000000000000000000000000000..4c9820334a7fe42eb28e4358ee51efab18ff3533 --- /dev/null +++ b/drivers/ftp/types.go @@ -0,0 +1 @@ +package ftp diff --git a/drivers/ftp/util.go b/drivers/ftp/util.go new file mode 100644 index 0000000000000000000000000000000000000000..196d874c9bcb949b122838366dcfd3229937f90b --- /dev/null +++ b/drivers/ftp/util.go @@ -0,0 +1,116 @@ +package ftp + +import ( + "io" + "os" + "sync" + "sync/atomic" + "time" + + "github.com/jlaffaye/ftp" +) + +// do others that not defined in Driver interface + +func (d *FTP) login() error { + if d.conn != nil { + _, err := d.conn.CurrentDir() + if err == nil { + return nil + } + } + conn, err := ftp.Dial(d.Address, ftp.DialWithShutTimeout(10*time.Second)) + if err != nil { + return err + } + err = conn.Login(d.Username, d.Password) + if err != nil { + return err + } + d.conn = conn + return nil +} + +// FileReader An FTP file reader that implements io.MFile for seeking. +type FileReader struct { + conn *ftp.ServerConn + resp *ftp.Response + offset atomic.Int64 + readAtOffset int64 + mu sync.Mutex + path string + size int64 +} + +func NewFileReader(conn *ftp.ServerConn, path string, size int64) *FileReader { + return &FileReader{ + conn: conn, + path: path, + size: size, + } +} + +func (r *FileReader) Read(buf []byte) (n int, err error) { + n, err = r.ReadAt(buf, r.offset.Load()) + r.offset.Add(int64(n)) + return +} + +func (r *FileReader) ReadAt(buf []byte, off int64) (n int, err error) { + if off < 0 { + return -1, os.ErrInvalid + } + r.mu.Lock() + defer r.mu.Unlock() + + if off != r.readAtOffset { + //have to restart the connection, to correct offset + _ = r.resp.Close() + r.resp = nil + } + + if r.resp == nil { + r.resp, err = r.conn.RetrFrom(r.path, uint64(off)) + r.readAtOffset = off + if err != nil { + return 0, err + } + } + + n, err = r.resp.Read(buf) + r.readAtOffset += int64(n) + return +} + +func (r *FileReader) Seek(offset int64, whence int) (int64, error) { + oldOffset := r.offset.Load() + var newOffset int64 + switch whence { + case io.SeekStart: + newOffset = offset + case io.SeekCurrent: + newOffset = oldOffset + offset + case io.SeekEnd: + return r.size, nil + default: + return -1, os.ErrInvalid + } + + if newOffset < 0 { + // offset out of range + return oldOffset, os.ErrInvalid + } + if newOffset == oldOffset { + // offset not changed, so return directly + return oldOffset, nil + } + r.offset.Store(newOffset) + return newOffset, nil +} + +func (r *FileReader) Close() error { + if r.resp != nil { + return r.resp.Close() + } + return nil +} diff --git a/drivers/google_drive/driver.go b/drivers/google_drive/driver.go new file mode 100644 index 0000000000000000000000000000000000000000..dccdcea902fcf51d31b0b9f9356a769121e37162 --- /dev/null +++ b/drivers/google_drive/driver.go @@ -0,0 +1,169 @@ +package google_drive + +import ( + "context" + "fmt" + "net/http" + "strconv" + + "github.com/alist-org/alist/v3/drivers/base" + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/errs" 
+ "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/go-resty/resty/v2" +) + +type GoogleDrive struct { + model.Storage + Addition + AccessToken string + ServiceAccountFile int + ServiceAccountFileList []string +} + +func (d *GoogleDrive) Config() driver.Config { + return config +} + +func (d *GoogleDrive) GetAddition() driver.Additional { + return &d.Addition +} + +func (d *GoogleDrive) Init(ctx context.Context) error { + if d.ChunkSize == 0 { + d.ChunkSize = 5 + } + return d.refreshToken() +} + +func (d *GoogleDrive) Drop(ctx context.Context) error { + return nil +} + +func (d *GoogleDrive) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) { + files, err := d.getFiles(dir.GetID()) + if err != nil { + return nil, err + } + return utils.SliceConvert(files, func(src File) (model.Obj, error) { + return fileToObj(src), nil + }) +} + +func (d *GoogleDrive) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) { + url := fmt.Sprintf("https://www.googleapis.com/drive/v3/files/%s?includeItemsFromAllDrives=true&supportsAllDrives=true", file.GetID()) + _, err := d.request(url, http.MethodGet, nil, nil) + if err != nil { + return nil, err + } + link := model.Link{ + URL: url + "&alt=media&acknowledgeAbuse=true", + Header: http.Header{ + "Authorization": []string{"Bearer " + d.AccessToken}, + }, + } + return &link, nil +} + +func (d *GoogleDrive) MakeDir(ctx context.Context, parentDir model.Obj, dirName string) error { + data := base.Json{ + "name": dirName, + "parents": []string{parentDir.GetID()}, + "mimeType": "application/vnd.google-apps.folder", + } + _, err := d.request("https://www.googleapis.com/drive/v3/files", http.MethodPost, func(req *resty.Request) { + req.SetBody(data) + }, nil) + return err +} + +func (d *GoogleDrive) Move(ctx context.Context, srcObj, dstDir model.Obj) error { + query := map[string]string{ + "addParents": dstDir.GetID(), + "removeParents": "root", + } + url := "https://www.googleapis.com/drive/v3/files/" + srcObj.GetID() + _, err := d.request(url, http.MethodPatch, func(req *resty.Request) { + req.SetQueryParams(query) + }, nil) + return err +} + +func (d *GoogleDrive) Rename(ctx context.Context, srcObj model.Obj, newName string) error { + data := base.Json{ + "name": newName, + } + url := "https://www.googleapis.com/drive/v3/files/" + srcObj.GetID() + _, err := d.request(url, http.MethodPatch, func(req *resty.Request) { + req.SetBody(data) + }, nil) + return err +} + +func (d *GoogleDrive) Copy(ctx context.Context, srcObj, dstDir model.Obj) error { + return errs.NotSupport +} + +func (d *GoogleDrive) Remove(ctx context.Context, obj model.Obj) error { + url := "https://www.googleapis.com/drive/v3/files/" + obj.GetID() + _, err := d.request(url, http.MethodDelete, nil, nil) + return err +} + +func (d *GoogleDrive) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error { + obj := stream.GetExist() + var ( + e Error + url string + data base.Json + res *resty.Response + err error + ) + if obj != nil { + url = fmt.Sprintf("https://www.googleapis.com/upload/drive/v3/files/%s?uploadType=resumable&supportsAllDrives=true", obj.GetID()) + data = base.Json{} + } else { + data = base.Json{ + "name": stream.GetName(), + "parents": []string{dstDir.GetID()}, + } + url = "https://www.googleapis.com/upload/drive/v3/files?uploadType=resumable&supportsAllDrives=true" + } + req := base.NoRedirectClient.R(). 
+ SetHeaders(map[string]string{ + "Authorization": "Bearer " + d.AccessToken, + "X-Upload-Content-Type": stream.GetMimetype(), + "X-Upload-Content-Length": strconv.FormatInt(stream.GetSize(), 10), + }). + SetError(&e).SetBody(data).SetContext(ctx) + if obj != nil { + res, err = req.Patch(url) + } else { + res, err = req.Post(url) + } + if err != nil { + return err + } + if e.Error.Code != 0 { + if e.Error.Code == 401 { + err = d.refreshToken() + if err != nil { + return err + } + return d.Put(ctx, dstDir, stream, up) + } + return fmt.Errorf("%s: %v", e.Error.Message, e.Error.Errors) + } + putUrl := res.Header().Get("location") + if stream.GetSize() < d.ChunkSize*1024*1024 { + _, err = d.request(putUrl, http.MethodPut, func(req *resty.Request) { + req.SetHeader("Content-Length", strconv.FormatInt(stream.GetSize(), 10)).SetBody(stream) + }, nil) + } else { + err = d.chunkUpload(ctx, stream, putUrl) + } + return err +} + +var _ driver.Driver = (*GoogleDrive)(nil) diff --git a/drivers/google_drive/meta.go b/drivers/google_drive/meta.go new file mode 100644 index 0000000000000000000000000000000000000000..b0ed3084c4b47457d1d5f8942b5c54ce662b122b --- /dev/null +++ b/drivers/google_drive/meta.go @@ -0,0 +1,28 @@ +package google_drive + +import ( + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/op" +) + +type Addition struct { + driver.RootID + RefreshToken string `json:"refresh_token" required:"true"` + OrderBy string `json:"order_by" type:"string" help:"such as: folder,name,modifiedTime"` + OrderDirection string `json:"order_direction" type:"select" options:"asc,desc"` + ClientID string `json:"client_id" required:"true" default:"202264815644.apps.googleusercontent.com"` + ClientSecret string `json:"client_secret" required:"true" default:"X4Z3ca8xfWDb1Voo-F9a7ZxJ"` + ChunkSize int64 `json:"chunk_size" type:"number" default:"5" help:"chunk size while uploading (unit: MB)"` +} + +var config = driver.Config{ + Name: "GoogleDrive", + OnlyProxy: true, + DefaultRoot: "root", +} + +func init() { + op.RegisterDriver(func() driver.Driver { + return &GoogleDrive{} + }) +} diff --git a/drivers/google_drive/types.go b/drivers/google_drive/types.go new file mode 100644 index 0000000000000000000000000000000000000000..075459327d96b5e11d08bfdee36123b5a207e90b --- /dev/null +++ b/drivers/google_drive/types.go @@ -0,0 +1,80 @@ +package google_drive + +import ( + "strconv" + "time" + + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/pkg/utils" + log "github.com/sirupsen/logrus" +) + +type TokenError struct { + Error string `json:"error"` + ErrorDescription string `json:"error_description"` +} + +type Files struct { + NextPageToken string `json:"nextPageToken"` + Files []File `json:"files"` +} + +type File struct { + Id string `json:"id"` + Name string `json:"name"` + MimeType string `json:"mimeType"` + ModifiedTime time.Time `json:"modifiedTime"` + CreatedTime time.Time `json:"createdTime"` + Size string `json:"size"` + ThumbnailLink string `json:"thumbnailLink"` + ShortcutDetails struct { + TargetId string `json:"targetId"` + TargetMimeType string `json:"targetMimeType"` + } `json:"shortcutDetails"` + + MD5Checksum string `json:"md5Checksum"` + SHA1Checksum string `json:"sha1Checksum"` + SHA256Checksum string `json:"sha256Checksum"` +} + +func fileToObj(f File) *model.ObjThumb { + log.Debugf("google file: %+v", f) + size, _ := strconv.ParseInt(f.Size, 10, 64) + obj := &model.ObjThumb{ + Object: model.Object{ + ID: f.Id, + Name: f.Name, + 
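+			// the Drive API reports size as a decimal string; it was parsed
+			// with strconv.ParseInt above, so folders (which omit the field)
+			// fall back to 0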
Size: size, + Ctime: f.CreatedTime, + Modified: f.ModifiedTime, + IsFolder: f.MimeType == "application/vnd.google-apps.folder", + HashInfo: utils.NewHashInfoByMap(map[*utils.HashType]string{ + utils.MD5: f.MD5Checksum, + utils.SHA1: f.SHA1Checksum, + utils.SHA256: f.SHA256Checksum, + }), + }, + Thumbnail: model.Thumbnail{ + Thumbnail: f.ThumbnailLink, + }, + } + if f.MimeType == "application/vnd.google-apps.shortcut" { + obj.ID = f.ShortcutDetails.TargetId + obj.IsFolder = f.ShortcutDetails.TargetMimeType == "application/vnd.google-apps.folder" + } + return obj +} + +type Error struct { + Error struct { + Errors []struct { + Domain string `json:"domain"` + Reason string `json:"reason"` + Message string `json:"message"` + LocationType string `json:"location_type"` + Location string `json:"location"` + } + Code int `json:"code"` + Message string `json:"message"` + } `json:"error"` +} diff --git a/drivers/google_drive/util.go b/drivers/google_drive/util.go new file mode 100644 index 0000000000000000000000000000000000000000..0d3801127a45620af4d16a82ca50e873088d26c4 --- /dev/null +++ b/drivers/google_drive/util.go @@ -0,0 +1,244 @@ +package google_drive + +import ( + "context" + "crypto/x509" + "encoding/pem" + "fmt" + "net/http" + "os" + "regexp" + "strconv" + "time" + + "github.com/alist-org/alist/v3/pkg/http_range" + + "github.com/alist-org/alist/v3/drivers/base" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/go-resty/resty/v2" + "github.com/golang-jwt/jwt/v4" + log "github.com/sirupsen/logrus" +) + +// do others that not defined in Driver interface + +type googleDriveServiceAccount struct { + //Type string `json:"type"` + //ProjectID string `json:"project_id"` + //PrivateKeyID string `json:"private_key_id"` + PrivateKey string `json:"private_key"` + ClientEMail string `json:"client_email"` + //ClientID string `json:"client_id"` + //AuthURI string `json:"auth_uri"` + TokenURI string `json:"token_uri"` + //AuthProviderX509CertURL string `json:"auth_provider_x509_cert_url"` + //ClientX509CertURL string `json:"client_x509_cert_url"` +} + +func (d *GoogleDrive) refreshToken() error { + // googleDriveServiceAccountFile gdsaFile + gdsaFile, gdsaFileErr := os.Stat(d.RefreshToken) + if gdsaFileErr == nil { + gdsaFileThis := d.RefreshToken + if gdsaFile.IsDir() { + if len(d.ServiceAccountFileList) <= 0 { + gdsaReadDir, gdsaDirErr := os.ReadDir(d.RefreshToken) + if gdsaDirErr != nil { + log.Error("read dir fail") + return gdsaDirErr + } + var gdsaFileList []string + for _, fi := range gdsaReadDir { + if !fi.IsDir() { + match, _ := regexp.MatchString("^.*\\.json$", fi.Name()) + if !match { + continue + } + gdsaDirText := d.RefreshToken + if d.RefreshToken[len(d.RefreshToken)-1:] != "/" { + gdsaDirText = d.RefreshToken + "/" + } + gdsaFileList = append(gdsaFileList, gdsaDirText+fi.Name()) + } + } + d.ServiceAccountFileList = gdsaFileList + gdsaFileThis = d.ServiceAccountFileList[d.ServiceAccountFile] + d.ServiceAccountFile++ + } else { + if d.ServiceAccountFile < len(d.ServiceAccountFileList) { + d.ServiceAccountFile++ + } else { + d.ServiceAccountFile = 0 + } + gdsaFileThis = d.ServiceAccountFileList[d.ServiceAccountFile] + } + } + + gdsaFileThisContent, err := os.ReadFile(gdsaFileThis) + if err != nil { + return err + } + + // Now let's unmarshal the data into `payload` + var jsonData googleDriveServiceAccount + err = utils.Json.Unmarshal(gdsaFileThisContent, &jsonData) + if err != nil { + return err + } + + gdsaScope := 
"https://www.googleapis.com/auth/drive https://www.googleapis.com/auth/drive.appdata https://www.googleapis.com/auth/drive.file https://www.googleapis.com/auth/drive.metadata https://www.googleapis.com/auth/drive.metadata.readonly https://www.googleapis.com/auth/drive.readonly https://www.googleapis.com/auth/drive.scripts" + + timeNow := time.Now() + var timeStart int64 = timeNow.Unix() + var timeEnd int64 = timeNow.Add(time.Minute * 60).Unix() + + // load private key from string + privateKeyPem, _ := pem.Decode([]byte(jsonData.PrivateKey)) + privateKey, _ := x509.ParsePKCS8PrivateKey(privateKeyPem.Bytes) + + jwtToken := jwt.NewWithClaims(jwt.SigningMethodRS256, + jwt.MapClaims{ + "iss": jsonData.ClientEMail, + "scope": gdsaScope, + "aud": jsonData.TokenURI, + "exp": timeEnd, + "iat": timeStart, + }) + assertion, err := jwtToken.SignedString(privateKey) + if err != nil { + return err + } + + var resp base.TokenResp + var e TokenError + res, err := base.RestyClient.R().SetResult(&resp).SetError(&e). + SetFormData(map[string]string{ + "assertion": assertion, + "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer", + }).Post(jsonData.TokenURI) + if err != nil { + return err + } + log.Debug(res.String()) + if e.Error != "" { + return fmt.Errorf(e.Error) + } + d.AccessToken = resp.AccessToken + return nil + } + if gdsaFileErr != nil && os.IsExist(gdsaFileErr) { + return gdsaFileErr + } + url := "https://www.googleapis.com/oauth2/v4/token" + var resp base.TokenResp + var e TokenError + res, err := base.RestyClient.R().SetResult(&resp).SetError(&e). + SetFormData(map[string]string{ + "client_id": d.ClientID, + "client_secret": d.ClientSecret, + "refresh_token": d.RefreshToken, + "grant_type": "refresh_token", + }).Post(url) + if err != nil { + return err + } + log.Debug(res.String()) + if e.Error != "" { + return fmt.Errorf(e.Error) + } + d.AccessToken = resp.AccessToken + return nil +} + +func (d *GoogleDrive) request(url string, method string, callback base.ReqCallback, resp interface{}) ([]byte, error) { + req := base.RestyClient.R() + req.SetHeader("Authorization", "Bearer "+d.AccessToken) + req.SetQueryParam("includeItemsFromAllDrives", "true") + req.SetQueryParam("supportsAllDrives", "true") + if callback != nil { + callback(req) + } + if resp != nil { + req.SetResult(resp) + } + var e Error + req.SetError(&e) + res, err := req.Execute(method, url) + if err != nil { + return nil, err + } + if e.Error.Code != 0 { + if e.Error.Code == 401 { + err = d.refreshToken() + if err != nil { + return nil, err + } + return d.request(url, method, callback, resp) + } + return nil, fmt.Errorf("%s: %v", e.Error.Message, e.Error.Errors) + } + return res.Body(), nil +} + +func (d *GoogleDrive) getFiles(id string) ([]File, error) { + pageToken := "first" + res := make([]File, 0) + for pageToken != "" { + if pageToken == "first" { + pageToken = "" + } + var resp Files + orderBy := "folder,name,modifiedTime desc" + if d.OrderBy != "" { + orderBy = d.OrderBy + " " + d.OrderDirection + } + query := map[string]string{ + "orderBy": orderBy, + "fields": "files(id,name,mimeType,size,modifiedTime,createdTime,thumbnailLink,shortcutDetails,md5Checksum,sha1Checksum,sha256Checksum),nextPageToken", + "pageSize": "1000", + "q": fmt.Sprintf("'%s' in parents and trashed = false", id), + //"includeItemsFromAllDrives": "true", + //"supportsAllDrives": "true", + "pageToken": pageToken, + } + _, err := d.request("https://www.googleapis.com/drive/v3/files", http.MethodGet, func(req *resty.Request) { + 
req.SetQueryParams(query) + }, &resp) + if err != nil { + return nil, err + } + pageToken = resp.NextPageToken + res = append(res, resp.Files...) + } + return res, nil +} + +func (d *GoogleDrive) chunkUpload(ctx context.Context, stream model.FileStreamer, url string) error { + var defaultChunkSize = d.ChunkSize * 1024 * 1024 + var offset int64 = 0 + for offset < stream.GetSize() { + if utils.IsCanceled(ctx) { + return ctx.Err() + } + chunkSize := stream.GetSize() - offset + if chunkSize > defaultChunkSize { + chunkSize = defaultChunkSize + } + reader, err := stream.RangeRead(http_range.Range{Start: offset, Length: chunkSize}) + if err != nil { + return err + } + _, err = d.request(url, http.MethodPut, func(req *resty.Request) { + req.SetHeaders(map[string]string{ + "Content-Length": strconv.FormatInt(chunkSize, 10), + "Content-Range": fmt.Sprintf("bytes %d-%d/%d", offset, offset+chunkSize-1, stream.GetSize()), + }).SetBody(reader).SetContext(ctx) + }, nil) + if err != nil { + return err + } + offset += chunkSize + } + return nil +} diff --git a/drivers/google_photo/driver.go b/drivers/google_photo/driver.go new file mode 100644 index 0000000000000000000000000000000000000000..b54132ef9edc0bb934f4a4a05032803ab5349982 --- /dev/null +++ b/drivers/google_photo/driver.go @@ -0,0 +1,160 @@ +package google_photo + +import ( + "context" + "fmt" + "net/http" + "strconv" + "strings" + + "github.com/alist-org/alist/v3/drivers/base" + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/errs" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/go-resty/resty/v2" +) + +type GooglePhoto struct { + model.Storage + Addition + AccessToken string +} + +func (d *GooglePhoto) Config() driver.Config { + return config +} + +func (d *GooglePhoto) GetAddition() driver.Additional { + return &d.Addition +} + +func (d *GooglePhoto) Init(ctx context.Context) error { + return d.refreshToken() +} + +func (d *GooglePhoto) Drop(ctx context.Context) error { + return nil +} + +func (d *GooglePhoto) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) { + files, err := d.getFiles(dir.GetID()) + if err != nil { + return nil, err + } + return utils.SliceConvert(files, func(src MediaItem) (model.Obj, error) { + return fileToObj(src), nil + }) +} + +func (d *GooglePhoto) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) { + f, err := d.getMedia(file.GetID()) + if err != nil { + return nil, err + } + + if strings.Contains(f.MimeType, "image/") { + return &model.Link{ + URL: f.BaseURL + "=d", + }, nil + } else if strings.Contains(f.MimeType, "video/") { + return &model.Link{ + URL: f.BaseURL + "=dv", + }, nil + } + return &model.Link{}, nil +} + +func (d *GooglePhoto) MakeDir(ctx context.Context, parentDir model.Obj, dirName string) error { + return errs.NotSupport +} + +func (d *GooglePhoto) Move(ctx context.Context, srcObj, dstDir model.Obj) error { + return errs.NotSupport +} + +func (d *GooglePhoto) Rename(ctx context.Context, srcObj model.Obj, newName string) error { + return errs.NotSupport +} + +func (d *GooglePhoto) Copy(ctx context.Context, srcObj, dstDir model.Obj) error { + return errs.NotSupport +} + +func (d *GooglePhoto) Remove(ctx context.Context, obj model.Obj) error { + return errs.NotSupport +} + +func (d *GooglePhoto) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error { + var e Error + // Create resumable 
upload url + postHeaders := map[string]string{ + "Authorization": "Bearer " + d.AccessToken, + "Content-type": "application/octet-stream", + "X-Goog-Upload-Command": "start", + "X-Goog-Upload-Content-Type": stream.GetMimetype(), + "X-Goog-Upload-Protocol": "resumable", + "X-Goog-Upload-Raw-Size": strconv.FormatInt(stream.GetSize(), 10), + } + url := "https://photoslibrary.googleapis.com/v1/uploads" + res, err := base.NoRedirectClient.R().SetHeaders(postHeaders). + SetError(&e). + Post(url) + + if err != nil { + return err + } + if e.Error.Code != 0 { + if e.Error.Code == 401 { + err = d.refreshToken() + if err != nil { + return err + } + return d.Put(ctx, dstDir, stream, up) + } + return fmt.Errorf("%s: %v", e.Error.Message, e.Error.Errors) + } + + //Upload to the Google Photo + postUrl := res.Header().Get("X-Goog-Upload-URL") + //chunkSize := res.Header().Get("X-Goog-Upload-Chunk-Granularity") + postHeaders = map[string]string{ + "X-Goog-Upload-Command": "upload, finalize", + "X-Goog-Upload-Offset": "0", + } + + resp, err := d.request(postUrl, http.MethodPost, func(req *resty.Request) { + req.SetBody(stream).SetContext(ctx) + }, nil, postHeaders) + + if err != nil { + return err + } + //Create MediaItem + createItemUrl := "https://photoslibrary.googleapis.com/v1/mediaItems:batchCreate" + + postHeaders = map[string]string{ + "X-Goog-Upload-Command": "upload, finalize", + "X-Goog-Upload-Offset": "0", + } + + data := base.Json{ + "newMediaItems": []base.Json{ + { + "description": "item-description", + "simpleMediaItem": base.Json{ + "fileName": stream.GetName(), + "uploadToken": string(resp), + }, + }, + }, + } + + _, err = d.request(createItemUrl, http.MethodPost, func(req *resty.Request) { + req.SetBody(data) + }, nil, postHeaders) + + return err +} + +var _ driver.Driver = (*GooglePhoto)(nil) diff --git a/drivers/google_photo/meta.go b/drivers/google_photo/meta.go new file mode 100644 index 0000000000000000000000000000000000000000..cc5f127234f2784ddacd88a9cb018ab0bbb0ae34 --- /dev/null +++ b/drivers/google_photo/meta.go @@ -0,0 +1,28 @@ +package google_photo + +import ( + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/op" +) + +type Addition struct { + driver.RootID + RefreshToken string `json:"refresh_token" required:"true"` + ClientID string `json:"client_id" required:"true" default:"202264815644.apps.googleusercontent.com"` + ClientSecret string `json:"client_secret" required:"true" default:"X4Z3ca8xfWDb1Voo-F9a7ZxJ"` + ShowArchive bool `json:"show_archive"` +} + +var config = driver.Config{ + Name: "GooglePhoto", + OnlyProxy: true, + DefaultRoot: "root", + NoUpload: true, + LocalSort: true, +} + +func init() { + op.RegisterDriver(func() driver.Driver { + return &GooglePhoto{} + }) +} diff --git a/drivers/google_photo/types.go b/drivers/google_photo/types.go new file mode 100644 index 0000000000000000000000000000000000000000..1a53ae09bfc56466e567413ca251a5d771d17462 --- /dev/null +++ b/drivers/google_photo/types.go @@ -0,0 +1,85 @@ +package google_photo + +import ( + "reflect" + "time" + + "github.com/alist-org/alist/v3/internal/model" +) + +type TokenError struct { + Error string `json:"error"` + ErrorDescription string `json:"error_description"` +} + +type Items struct { + NextPageToken string `json:"nextPageToken"` + MediaItems []MediaItem `json:"mediaItems,omitempty"` + Albums []MediaItem `json:"albums,omitempty"` + SharedAlbums []MediaItem `json:"sharedAlbums,omitempty"` +} + +type MediaItem struct { + Id string `json:"id"` + Title string 
`json:"title,omitempty"` + BaseURL string `json:"baseUrl,omitempty"` + CoverPhotoBaseUrl string `json:"coverPhotoBaseUrl,omitempty"` + MimeType string `json:"mimeType,omitempty"` + FileName string `json:"filename,omitempty"` + MediaMetadata MediaMetadata `json:"mediaMetadata,omitempty"` +} + +type MediaMetadata struct { + CreationTime time.Time `json:"creationTime"` + Width string `json:"width"` + Height string `json:"height"` + Photo Photo `json:"photo,omitempty"` + Video Video `json:"video,omitempty"` +} + +type Photo struct { +} + +type Video struct { +} + +func fileToObj(f MediaItem) *model.ObjThumb { + if !reflect.DeepEqual(f.MediaMetadata, MediaMetadata{}){ + return &model.ObjThumb{ + Object: model.Object{ + ID: f.Id, + Name: f.FileName, + Size: 0, + Modified: f.MediaMetadata.CreationTime, + IsFolder: false, + }, + Thumbnail: model.Thumbnail{ + Thumbnail: f.BaseURL + "=w100-h100-c", + }, + } + } + return &model.ObjThumb{ + Object: model.Object{ + ID: f.Id, + Name: f.Title, + Size: 0, + Modified: time.Time{}, + IsFolder: true, + }, + Thumbnail: model.Thumbnail{}, + } +} + +type Error struct { + Error struct { + Errors []struct { + Domain string `json:"domain"` + Reason string `json:"reason"` + Message string `json:"message"` + LocationType string `json:"location_type"` + Location string `json:"location"` + } + Code int `json:"code"` + Message string `json:"message"` + } `json:"error"` +} diff --git a/drivers/google_photo/util.go b/drivers/google_photo/util.go new file mode 100644 index 0000000000000000000000000000000000000000..0fd271b9bb4763283ce752f71fc2be01b570f065 --- /dev/null +++ b/drivers/google_photo/util.go @@ -0,0 +1,186 @@ +package google_photo + +import ( + "fmt" + "net/http" + + "github.com/alist-org/alist/v3/drivers/base" + "github.com/go-resty/resty/v2" +) + +// do others that not defined in Driver interface + +const ( + FETCH_ALL = "all" + FETCH_ALBUMS = "albums" + FETCH_ROOT = "root" + FETCH_SHARE_ALBUMS = "share_albums" +) + +func (d *GooglePhoto) refreshToken() error { + url := "https://www.googleapis.com/oauth2/v4/token" + var resp base.TokenResp + var e TokenError + _, err := base.RestyClient.R().SetResult(&resp).SetError(&e). 
+ SetFormData(map[string]string{ + "client_id": d.ClientID, + "client_secret": d.ClientSecret, + "refresh_token": d.RefreshToken, + "grant_type": "refresh_token", + }).Post(url) + if err != nil { + return err + } + if e.Error != "" { + return fmt.Errorf(e.Error) + } + d.AccessToken = resp.AccessToken + return nil +} + +func (d *GooglePhoto) request(url string, method string, callback base.ReqCallback, resp interface{}, headers map[string]string) ([]byte, error) { + req := base.RestyClient.R() + req.SetHeader("Authorization", "Bearer "+d.AccessToken) + req.SetHeader("Accept-Encoding", "gzip") + if headers != nil { + req.SetHeaders(headers) + } + + if callback != nil { + callback(req) + } + if resp != nil { + req.SetResult(resp) + } + var e Error + req.SetError(&e) + res, err := req.Execute(method, url) + if err != nil { + return nil, err + } + if e.Error.Code != 0 { + if e.Error.Code == 401 { + err = d.refreshToken() + if err != nil { + return nil, err + } + return d.request(url, method, callback, resp, headers) + } + return nil, fmt.Errorf("%s: %v", e.Error.Message, e.Error.Errors) + } + return res.Body(), nil +} + +func (d *GooglePhoto) getFiles(id string) ([]MediaItem, error) { + switch id { + case FETCH_ALL: + return d.getAllMedias() + case FETCH_ALBUMS: + return d.getAlbums() + case FETCH_SHARE_ALBUMS: + return d.getShareAlbums() + case FETCH_ROOT: + return d.getFakeRoot() + default: + return d.getMedias(id) + } +} + +func (d *GooglePhoto) getFakeRoot() ([]MediaItem, error) { + return []MediaItem{ + { + Id: FETCH_ALL, + Title: "全部媒体", + }, + { + Id: FETCH_ALBUMS, + Title: "全部影集", + }, + { + Id: FETCH_SHARE_ALBUMS, + Title: "共享影集", + }, + }, nil +} + +func (d *GooglePhoto) getAlbums() ([]MediaItem, error) { + return d.fetchItems( + "https://photoslibrary.googleapis.com/v1/albums", + map[string]string{ + "fields": "albums(id,title,coverPhotoBaseUrl),nextPageToken", + "pageSize": "50", + "pageToken": "first", + }, + http.MethodGet) +} + +func (d *GooglePhoto) getShareAlbums() ([]MediaItem, error) { + return d.fetchItems( + "https://photoslibrary.googleapis.com/v1/sharedAlbums", + map[string]string{ + "fields": "sharedAlbums(id,title,coverPhotoBaseUrl),nextPageToken", + "pageSize": "50", + "pageToken": "first", + }, + http.MethodGet) +} + +func (d *GooglePhoto) getMedias(albumId string) ([]MediaItem, error) { + return d.fetchItems( + "https://photoslibrary.googleapis.com/v1/mediaItems:search", + map[string]string{ + "fields": "mediaItems(id,baseUrl,mimeType,mediaMetadata,filename),nextPageToken", + "pageSize": "100", + "albumId": albumId, + "pageToken": "first", + }, http.MethodPost) +} + +func (d *GooglePhoto) getAllMedias() ([]MediaItem, error) { + return d.fetchItems( + "https://photoslibrary.googleapis.com/v1/mediaItems", + map[string]string{ + "fields": "mediaItems(id,baseUrl,mimeType,mediaMetadata,filename),nextPageToken", + "pageSize": "100", + "pageToken": "first", + }, + http.MethodGet) +} + +func (d *GooglePhoto) getMedia(id string) (MediaItem, error) { + var resp MediaItem + + query := map[string]string{ + "fields": "mediaMetadata,baseUrl,mimeType", + } + _, err := d.request(fmt.Sprintf("https://photoslibrary.googleapis.com/v1/mediaItems/%s", id), http.MethodGet, func(req *resty.Request) { + req.SetQueryParams(query) + }, &resp, nil) + if err != nil { + return resp, err + } + + return resp, nil +} + +func (d *GooglePhoto) fetchItems(url string, query map[string]string, method string) ([]MediaItem, error){ + res := make([]MediaItem, 0) + for query["pageToken"] != "" { + if 
query["pageToken"] == "first" { + query["pageToken"] = "" + } + var resp Items + + _, err := d.request(url, method, func(req *resty.Request) { + req.SetQueryParams(query) + }, &resp, nil) + if err != nil { + return nil, err + } + query["pageToken"] = resp.NextPageToken + res = append(res, resp.MediaItems...) + res = append(res, resp.Albums...) + res = append(res, resp.SharedAlbums...) + } + return res, nil +} diff --git a/drivers/halalcloud/driver.go b/drivers/halalcloud/driver.go new file mode 100644 index 0000000000000000000000000000000000000000..08bb3808bfd270f25584e41983d23bcfe9966d4e --- /dev/null +++ b/drivers/halalcloud/driver.go @@ -0,0 +1,406 @@ +package halalcloud + +import ( + "context" + "crypto/sha1" + "fmt" + "github.com/alist-org/alist/v3/drivers/base" + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/internal/op" + "github.com/alist-org/alist/v3/pkg/http_range" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/s3/s3manager" + "github.com/city404/v6-public-rpc-proto/go/v6/common" + pbPublicUser "github.com/city404/v6-public-rpc-proto/go/v6/user" + pubUserFile "github.com/city404/v6-public-rpc-proto/go/v6/userfile" + "github.com/rclone/rclone/lib/readers" + "github.com/zzzhr1990/go-common-entity/userfile" + "io" + "net/url" + "path" + "strconv" + "time" +) + +type HalalCloud struct { + *HalalCommon + model.Storage + Addition + + uploadThread int +} + +func (d *HalalCloud) Config() driver.Config { + return config +} + +func (d *HalalCloud) GetAddition() driver.Additional { + return &d.Addition +} + +func (d *HalalCloud) Init(ctx context.Context) error { + d.uploadThread, _ = strconv.Atoi(d.UploadThread) + if d.uploadThread < 1 || d.uploadThread > 32 { + d.uploadThread, d.UploadThread = 3, "3" + } + + if d.HalalCommon == nil { + d.HalalCommon = &HalalCommon{ + Common: &Common{}, + AuthService: &AuthService{ + appID: func() string { + if d.Addition.AppID != "" { + return d.Addition.AppID + } + return AppID + }(), + appVersion: func() string { + if d.Addition.AppVersion != "" { + return d.Addition.AppVersion + } + return AppVersion + }(), + appSecret: func() string { + if d.Addition.AppSecret != "" { + return d.Addition.AppSecret + } + return AppSecret + }(), + tr: &TokenResp{ + RefreshToken: d.Addition.RefreshToken, + }, + }, + UserInfo: &UserInfo{}, + refreshTokenFunc: func(token string) error { + d.Addition.RefreshToken = token + op.MustSaveDriverStorage(d) + return nil + }, + } + } + + // 防止重复登录 + if d.Addition.RefreshToken == "" || !d.IsLogin() { + as, err := d.NewAuthServiceWithOauth() + if err != nil { + d.GetStorage().SetStatus(fmt.Sprintf("%+v", err.Error())) + return err + } + d.HalalCommon.AuthService = as + d.SetTokenResp(as.tr) + op.MustSaveDriverStorage(d) + } + var err error + d.HalalCommon.serv, err = d.NewAuthService(d.Addition.RefreshToken) + if err != nil { + return err + } + + return nil +} + +func (d *HalalCloud) Drop(ctx context.Context) error { + return nil +} + +func (d *HalalCloud) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) { + return d.getFiles(ctx, dir) +} + +func (d *HalalCloud) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) { + return d.getLink(ctx, file, args) +} + +func (d *HalalCloud) MakeDir(ctx context.Context, parentDir model.Obj, dirName 
string) (model.Obj, error) { + return d.makeDir(ctx, parentDir, dirName) +} + +func (d *HalalCloud) Move(ctx context.Context, srcObj, dstDir model.Obj) (model.Obj, error) { + return d.move(ctx, srcObj, dstDir) +} + +func (d *HalalCloud) Rename(ctx context.Context, srcObj model.Obj, newName string) (model.Obj, error) { + return d.rename(ctx, srcObj, newName) +} + +func (d *HalalCloud) Copy(ctx context.Context, srcObj, dstDir model.Obj) (model.Obj, error) { + return d.copy(ctx, srcObj, dstDir) +} + +func (d *HalalCloud) Remove(ctx context.Context, obj model.Obj) error { + return d.remove(ctx, obj) +} + +func (d *HalalCloud) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) (model.Obj, error) { + return d.put(ctx, dstDir, stream, up) +} + +func (d *HalalCloud) IsLogin() bool { + if d.AuthService.tr == nil { + return false + } + serv, err := d.NewAuthService(d.Addition.RefreshToken) + if err != nil { + return false + } + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + result, err := pbPublicUser.NewPubUserClient(serv.GetGrpcConnection()).Get(ctx, &pbPublicUser.User{ + Identity: "", + }) + if result == nil || err != nil { + return false + } + d.UserInfo.Identity = result.Identity + d.UserInfo.CreateTs = result.CreateTs + d.UserInfo.Name = result.Name + d.UserInfo.UpdateTs = result.UpdateTs + return true +} + +type HalalCommon struct { + *Common + *AuthService // 登录信息 + *UserInfo // 用户信息 + refreshTokenFunc func(token string) error + serv *AuthService +} + +func (d *HalalCloud) SetTokenResp(tr *TokenResp) { + d.Addition.RefreshToken = tr.RefreshToken +} + +func (d *HalalCloud) getFiles(ctx context.Context, dir model.Obj) ([]model.Obj, error) { + + files := make([]model.Obj, 0) + limit := int64(100) + token := "" + client := pubUserFile.NewPubUserFileClient(d.HalalCommon.serv.GetGrpcConnection()) + + opDir := d.GetCurrentDir(dir) + + for { + result, err := client.List(ctx, &pubUserFile.FileListRequest{ + Parent: &pubUserFile.File{Path: opDir}, + ListInfo: &common.ScanListRequest{ + Limit: limit, + Token: token, + }, + }) + if err != nil { + return nil, err + } + + for i := 0; len(result.Files) > i; i++ { + files = append(files, (*Files)(result.Files[i])) + } + + if result.ListInfo == nil || result.ListInfo.Token == "" { + break + } + token = result.ListInfo.Token + + } + return files, nil +} + +func (d *HalalCloud) getLink(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) { + + client := pubUserFile.NewPubUserFileClient(d.HalalCommon.serv.GetGrpcConnection()) + ctx1, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() + + result, err := client.ParseFileSlice(ctx1, (*pubUserFile.File)(file.(*Files))) + if err != nil { + return nil, err + } + fileAddrs := []*pubUserFile.SliceDownloadInfo{} + var addressDuration int64 + + nodesNumber := len(result.RawNodes) + nodesIndex := nodesNumber - 1 + startIndex, endIndex := 0, nodesIndex + for nodesIndex >= 0 { + if nodesIndex >= 200 { + endIndex = 200 + } else { + endIndex = nodesNumber + } + for ; endIndex <= nodesNumber; endIndex += 200 { + if endIndex == 0 { + endIndex = 1 + } + sliceAddress, err := client.GetSliceDownloadAddress(ctx, &pubUserFile.SliceDownloadAddressRequest{ + Identity: result.RawNodes[startIndex:endIndex], + Version: 1, + }) + if err != nil { + return nil, err + } + addressDuration = sliceAddress.ExpireAt + fileAddrs = append(fileAddrs, sliceAddress.Addresses...) 
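+			// slide the 200-entry window forward and decrease the number of
+			// nodes still waiting for download addresses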
+ startIndex = endIndex + nodesIndex -= 200 + } + + } + + size := result.FileSize + chunks := getChunkSizes(result.Sizes) + var finalClosers utils.Closers + resultRangeReader := func(ctx context.Context, httpRange http_range.Range) (io.ReadCloser, error) { + length := httpRange.Length + if httpRange.Length >= 0 && httpRange.Start+httpRange.Length >= size { + length = -1 + } + if err != nil { + return nil, fmt.Errorf("open download file failed: %w", err) + } + oo := &openObject{ + ctx: ctx, + d: fileAddrs, + chunk: &[]byte{}, + chunks: &chunks, + skip: httpRange.Start, + sha: result.Sha1, + shaTemp: sha1.New(), + } + finalClosers.Add(oo) + + return readers.NewLimitedReadCloser(oo, length), nil + } + + var duration time.Duration + if addressDuration != 0 { + duration = time.Until(time.UnixMilli(addressDuration)) + } else { + duration = time.Until(time.Now().Add(time.Hour)) + } + + resultRangeReadCloser := &model.RangeReadCloser{RangeReader: resultRangeReader, Closers: finalClosers} + return &model.Link{ + RangeReadCloser: resultRangeReadCloser, + Expiration: &duration, + }, nil +} + +func (d *HalalCloud) makeDir(ctx context.Context, dir model.Obj, name string) (model.Obj, error) { + newDir := userfile.NewFormattedPath(d.GetCurrentOpDir(dir, []string{name}, 0)).GetPath() + _, err := pubUserFile.NewPubUserFileClient(d.HalalCommon.serv.GetGrpcConnection()).Create(ctx, &pubUserFile.File{ + Path: newDir, + }) + return nil, err +} + +func (d *HalalCloud) move(ctx context.Context, obj model.Obj, dir model.Obj) (model.Obj, error) { + oldDir := userfile.NewFormattedPath(d.GetCurrentDir(obj)).GetPath() + newDir := userfile.NewFormattedPath(d.GetCurrentDir(dir)).GetPath() + _, err := pubUserFile.NewPubUserFileClient(d.HalalCommon.serv.GetGrpcConnection()).Move(ctx, &pubUserFile.BatchOperationRequest{ + Source: []*pubUserFile.File{ + { + Identity: obj.GetID(), + Path: oldDir, + }, + }, + Dest: &pubUserFile.File{ + Identity: dir.GetID(), + Path: newDir, + }, + }) + return nil, err +} + +func (d *HalalCloud) rename(ctx context.Context, obj model.Obj, name string) (model.Obj, error) { + id := obj.GetID() + newPath := userfile.NewFormattedPath(d.GetCurrentOpDir(obj, []string{name}, 0)).GetPath() + + _, err := pubUserFile.NewPubUserFileClient(d.HalalCommon.serv.GetGrpcConnection()).Rename(ctx, &pubUserFile.File{ + Path: newPath, + Identity: id, + Name: name, + }) + return nil, err +} + +func (d *HalalCloud) copy(ctx context.Context, obj model.Obj, dir model.Obj) (model.Obj, error) { + id := obj.GetID() + sourcePath := userfile.NewFormattedPath(d.GetCurrentDir(obj)).GetPath() + if len(id) > 0 { + sourcePath = "" + } + dest := &pubUserFile.File{ + Identity: dir.GetID(), + Path: userfile.NewFormattedPath(d.GetCurrentDir(dir)).GetPath(), + } + _, err := pubUserFile.NewPubUserFileClient(d.HalalCommon.serv.GetGrpcConnection()).Copy(ctx, &pubUserFile.BatchOperationRequest{ + Source: []*pubUserFile.File{ + { + Path: sourcePath, + Identity: id, + }, + }, + Dest: dest, + }) + return nil, err +} + +func (d *HalalCloud) remove(ctx context.Context, obj model.Obj) error { + id := obj.GetID() + newPath := userfile.NewFormattedPath(d.GetCurrentDir(obj)).GetPath() + //if len(id) > 0 { + // newPath = "" + //} + _, err := pubUserFile.NewPubUserFileClient(d.HalalCommon.serv.GetGrpcConnection()).Delete(ctx, &pubUserFile.BatchOperationRequest{ + Source: []*pubUserFile.File{ + { + Path: newPath, + Identity: id, + }, + }, + }) + return err +} + +func (d *HalalCloud) put(ctx context.Context, dstDir model.Obj, fileStream 
model.FileStreamer, up driver.UpdateProgress) (model.Obj, error) { + + newDir := path.Join(dstDir.GetPath(), fileStream.GetName()) + + result, err := pubUserFile.NewPubUserFileClient(d.HalalCommon.serv.GetGrpcConnection()).CreateUploadToken(ctx, &pubUserFile.File{ + Path: newDir, + }) + if err != nil { + return nil, err + } + u, _ := url.Parse(result.Endpoint) + u.Host = "s3." + u.Host + result.Endpoint = u.String() + s, err := session.NewSession(&aws.Config{ + HTTPClient: base.HttpClient, + Credentials: credentials.NewStaticCredentials(result.AccessKey, result.SecretKey, result.Token), + Region: aws.String(result.Region), + Endpoint: aws.String(result.Endpoint), + S3ForcePathStyle: aws.Bool(true), + }) + if err != nil { + return nil, err + } + uploader := s3manager.NewUploader(s, func(u *s3manager.Uploader) { + u.Concurrency = d.uploadThread + }) + if fileStream.GetSize() > s3manager.MaxUploadParts*s3manager.DefaultUploadPartSize { + uploader.PartSize = fileStream.GetSize() / (s3manager.MaxUploadParts - 1) + } + _, err = uploader.UploadWithContext(ctx, &s3manager.UploadInput{ + Bucket: aws.String(result.Bucket), + Key: aws.String(result.Key), + Body: io.TeeReader(fileStream, driver.NewProgress(fileStream.GetSize(), up)), + }) + return nil, err + +} + +var _ driver.Driver = (*HalalCloud)(nil) diff --git a/drivers/halalcloud/meta.go b/drivers/halalcloud/meta.go new file mode 100644 index 0000000000000000000000000000000000000000..d4040323eb07768b55ed2213de8f24a3e44669d8 --- /dev/null +++ b/drivers/halalcloud/meta.go @@ -0,0 +1,38 @@ +package halalcloud + +import ( + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/op" +) + +type Addition struct { + // Usually one of two + driver.RootPath + // define other + RefreshToken string `json:"refresh_token" required:"true" help:"login type is refresh_token,this is required"` + UploadThread string `json:"upload_thread" default:"3" help:"1 <= thread <= 32"` + + AppID string `json:"app_id" required:"true" default:"alist/10001"` + AppVersion string `json:"app_version" required:"true" default:"1.0.0"` + AppSecret string `json:"app_secret" required:"true" default:"bR4SJwOkvnG5WvVJ"` +} + +var config = driver.Config{ + Name: "HalalCloud", + LocalSort: false, + OnlyLocal: true, + OnlyProxy: true, + NoCache: false, + NoUpload: false, + NeedMs: false, + DefaultRoot: "/", + CheckStatus: false, + Alert: "", + NoOverwriteUpload: false, +} + +func init() { + op.RegisterDriver(func() driver.Driver { + return &HalalCloud{} + }) +} diff --git a/drivers/halalcloud/options.go b/drivers/halalcloud/options.go new file mode 100644 index 0000000000000000000000000000000000000000..56e5fdc5c096d1e02f8f113e1245d7423835a32a --- /dev/null +++ b/drivers/halalcloud/options.go @@ -0,0 +1,52 @@ +package halalcloud + +import "google.golang.org/grpc" + +func defaultOptions() halalOptions { + return halalOptions{ + // onRefreshTokenRefreshed: func(string) {}, + grpcOptions: []grpc.DialOption{ + grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(1024 * 1024 * 32)), + // grpc.WithMaxMsgSize(1024 * 1024 * 1024), + }, + } +} + +type HalalOption interface { + apply(*halalOptions) +} + +// halalOptions configure a RPC call. halalOptions are set by the HalalOption +// values passed to Dial. 
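+//
+// A hypothetical usage sketch (option constructors as defined below in this
+// file; the exact call site is an assumption, not part of this patch):
+//
+//	svc, err := d.NewAuthService(refreshToken,
+//		WithGrpcDialOptions(grpc.WithUserAgent("alist")),
+//		WithRefreshTokenRefreshedCallback(func(at string, atExp int64, rt string, rtExp int64) {
+//			// persist rt so the next start can skip the OAuth flow
+//		}),
+//	)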
+type halalOptions struct { + onTokenRefreshed func(accessToken string, accessTokenExpiredAt int64, refreshToken string, refreshTokenExpiredAt int64) + grpcOptions []grpc.DialOption +} + +// funcDialOption wraps a function that modifies halalOptions into an +// implementation of the DialOption interface. +type funcDialOption struct { + f func(*halalOptions) +} + +func (fdo *funcDialOption) apply(do *halalOptions) { + fdo.f(do) +} + +func newFuncDialOption(f func(*halalOptions)) *funcDialOption { + return &funcDialOption{ + f: f, + } +} + +func WithRefreshTokenRefreshedCallback(s func(accessToken string, accessTokenExpiredAt int64, refreshToken string, refreshTokenExpiredAt int64)) HalalOption { + return newFuncDialOption(func(o *halalOptions) { + o.onTokenRefreshed = s + }) +} + +func WithGrpcDialOptions(opts ...grpc.DialOption) HalalOption { + return newFuncDialOption(func(o *halalOptions) { + o.grpcOptions = opts + }) +} diff --git a/drivers/halalcloud/types.go b/drivers/halalcloud/types.go new file mode 100644 index 0000000000000000000000000000000000000000..9772421264bfb6bbe3250d5f86ce19300564a4e9 --- /dev/null +++ b/drivers/halalcloud/types.go @@ -0,0 +1,101 @@ +package halalcloud + +import ( + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/city404/v6-public-rpc-proto/go/v6/common" + pubUserFile "github.com/city404/v6-public-rpc-proto/go/v6/userfile" + "google.golang.org/grpc" + "time" +) + +type AuthService struct { + appID string + appVersion string + appSecret string + grpcConnection *grpc.ClientConn + dopts halalOptions + tr *TokenResp +} + +type TokenResp struct { + AccessToken string `json:"accessToken,omitempty"` + AccessTokenExpiredAt int64 `json:"accessTokenExpiredAt,omitempty"` + RefreshToken string `json:"refreshToken,omitempty"` + RefreshTokenExpiredAt int64 `json:"refreshTokenExpiredAt,omitempty"` +} + +type UserInfo struct { + Identity string `json:"identity,omitempty"` + UpdateTs int64 `json:"updateTs,omitempty"` + Name string `json:"name,omitempty"` + CreateTs int64 `json:"createTs,omitempty"` +} + +type OrderByInfo struct { + Field string `json:"field,omitempty"` + Asc bool `json:"asc,omitempty"` +} + +type ListInfo struct { + Token string `json:"token,omitempty"` + Limit int64 `json:"limit,omitempty"` + OrderBy []*OrderByInfo `json:"order_by,omitempty"` + Version int32 `json:"version,omitempty"` +} + +type FilesList struct { + Files []*Files `json:"files,omitempty"` + ListInfo *common.ScanListRequest `json:"list_info,omitempty"` +} + +var _ model.Obj = (*Files)(nil) + +type Files pubUserFile.File + +func (f *Files) GetSize() int64 { + return f.Size +} + +func (f *Files) GetName() string { + return f.Name +} + +func (f *Files) ModTime() time.Time { + return time.UnixMilli(f.UpdateTs) +} + +func (f *Files) CreateTime() time.Time { + return time.UnixMilli(f.UpdateTs) +} + +func (f *Files) IsDir() bool { + return f.Dir +} + +func (f *Files) GetHash() utils.HashInfo { + return utils.HashInfo{} +} + +func (f *Files) GetID() string { + if len(f.Identity) == 0 { + f.Identity = "/" + } + return f.Identity +} + +func (f *Files) GetPath() string { + return f.Path +} + +type SteamFile struct { + file model.File +} + +func (s *SteamFile) Read(p []byte) (n int, err error) { + return s.file.Read(p) +} + +func (s *SteamFile) Close() error { + return s.file.Close() +} diff --git a/drivers/halalcloud/util.go b/drivers/halalcloud/util.go new file mode 100644 index 
0000000000000000000000000000000000000000..f3012a8c83c4589dcc3ec44082c364a04eaf9fd6 --- /dev/null +++ b/drivers/halalcloud/util.go @@ -0,0 +1,385 @@ +package halalcloud + +import ( + "bytes" + "context" + "crypto/md5" + "crypto/tls" + "encoding/hex" + "errors" + "fmt" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/pkg/utils" + pbPublicUser "github.com/city404/v6-public-rpc-proto/go/v6/user" + pubUserFile "github.com/city404/v6-public-rpc-proto/go/v6/userfile" + "github.com/google/uuid" + "github.com/ipfs/go-cid" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" + "hash" + "io" + "net/http" + "strconv" + "strings" + "sync" + "time" +) + +const ( + AppID = "alist/10001" + AppVersion = "1.0.0" + AppSecret = "bR4SJwOkvnG5WvVJ" +) + +const ( + grpcServer = "grpcuserapi.2dland.cn:443" + grpcServerAuth = "grpcuserapi.2dland.cn" +) + +func (d *HalalCloud) NewAuthServiceWithOauth(options ...HalalOption) (*AuthService, error) { + + aService := &AuthService{} + err2 := errors.New("") + + svc := d.HalalCommon.AuthService + for _, opt := range options { + opt.apply(&svc.dopts) + } + + grpcOptions := svc.dopts.grpcOptions + grpcOptions = append(grpcOptions, grpc.WithAuthority(grpcServerAuth), grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{})), grpc.WithUnaryInterceptor(func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { + ctxx := svc.signContext(method, ctx) + err := invoker(ctxx, method, req, reply, cc, opts...) // invoking RPC method + return err + })) + + grpcConnection, err := grpc.NewClient(grpcServer, grpcOptions...) 
+ if err != nil { + return nil, err + } + defer grpcConnection.Close() + userClient := pbPublicUser.NewPubUserClient(grpcConnection) + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + stateString := uuid.New().String() + // queryValues.Add("callback", oauthToken.Callback) + oauthToken, err := userClient.CreateAuthToken(ctx, &pbPublicUser.LoginRequest{ + ReturnType: 2, + State: stateString, + ReturnUrl: "", + }) + if err != nil { + return nil, err + } + if len(oauthToken.State) < 1 { + oauthToken.State = stateString + } + + if oauthToken.Url != "" { + + return nil, fmt.Errorf(`need verify: <a target="_blank" href="%s">Click Here</a>`, oauthToken.Url) + } + + return aService, err2 + +} + +func (d *HalalCloud) NewAuthService(refreshToken string, options ...HalalOption) (*AuthService, error) { + svc := d.HalalCommon.AuthService + + if len(refreshToken) < 1 { + refreshToken = d.Addition.RefreshToken + } + + if len(d.tr.AccessToken) > 0 { + accessTokenExpiredAt := d.tr.AccessTokenExpiredAt + current := time.Now().UnixMilli() + if accessTokenExpiredAt < current { + // access token expired + d.tr.AccessToken = "" + d.tr.AccessTokenExpiredAt = 0 + } else { + svc.tr.AccessTokenExpiredAt = accessTokenExpiredAt + svc.tr.AccessToken = d.tr.AccessToken + } + } + + for _, opt := range options { + opt.apply(&svc.dopts) + } + + grpcOptions := svc.dopts.grpcOptions + grpcOptions = append(grpcOptions, grpc.WithDefaultCallOptions(grpc.MaxCallSendMsgSize(10*1024*1024), grpc.MaxCallRecvMsgSize(10*1024*1024)), grpc.WithAuthority(grpcServerAuth), grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{})), grpc.WithUnaryInterceptor(func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { + ctxx := svc.signContext(method, ctx) + err := invoker(ctxx, method, req, reply, cc, opts...) // invoking RPC method + if err != nil { + grpcStatus, ok := status.FromError(err) + + if ok && grpcStatus.Code() == codes.Unauthenticated && strings.Contains(grpcStatus.Err().Error(), "invalid accesstoken") && len(refreshToken) > 0 { + // refresh token + refreshResponse, err := pbPublicUser.NewPubUserClient(cc).Refresh(ctx, &pbPublicUser.Token{ + RefreshToken: refreshToken, + }) + if err != nil { + return err + } + if len(refreshResponse.AccessToken) > 0 { + svc.tr.AccessToken = refreshResponse.AccessToken + svc.tr.AccessTokenExpiredAt = refreshResponse.AccessTokenExpireTs + svc.OnAccessTokenRefreshed(refreshResponse.AccessToken, refreshResponse.AccessTokenExpireTs, refreshResponse.RefreshToken, refreshResponse.RefreshTokenExpireTs) + } + // retry with the refreshed token + ctxx := svc.signContext(method, ctx) + return invoker(ctxx, method, req, reply, cc, opts...) + } + } + return err + })) + grpcConnection, err := grpc.NewClient(grpcServer, grpcOptions...)
+ + if err != nil { + return nil, err + } + + svc.grpcConnection = grpcConnection + return svc, err +} + +func (s *AuthService) OnAccessTokenRefreshed(accessToken string, accessTokenExpiredAt int64, refreshToken string, refreshTokenExpiredAt int64) { + s.tr.AccessToken = accessToken + s.tr.AccessTokenExpiredAt = accessTokenExpiredAt + s.tr.RefreshToken = refreshToken + s.tr.RefreshTokenExpiredAt = refreshTokenExpiredAt + + if s.dopts.onTokenRefreshed != nil { + s.dopts.onTokenRefreshed(accessToken, accessTokenExpiredAt, refreshToken, refreshTokenExpiredAt) + } + +} + +func (s *AuthService) GetGrpcConnection() *grpc.ClientConn { + return s.grpcConnection +} + +func (s *AuthService) Close() { + _ = s.grpcConnection.Close() +} + +func (s *AuthService) signContext(method string, ctx context.Context) context.Context { + var kvString []string + currentTimeStamp := strconv.FormatInt(time.Now().UnixMilli(), 10) + bufferedString := bytes.NewBufferString(method) + kvString = append(kvString, "timestamp", currentTimeStamp) + bufferedString.WriteString(currentTimeStamp) + kvString = append(kvString, "appid", s.appID) + bufferedString.WriteString(s.appID) + kvString = append(kvString, "appversion", s.appVersion) + bufferedString.WriteString(s.appVersion) + if s.tr != nil && len(s.tr.AccessToken) > 0 { + authorization := "Bearer " + s.tr.AccessToken + kvString = append(kvString, "authorization", authorization) + bufferedString.WriteString(authorization) + } + bufferedString.WriteString(s.appSecret) + sign := GetMD5Hash(bufferedString.String()) + kvString = append(kvString, "sign", sign) + return metadata.AppendToOutgoingContext(ctx, kvString...) +} + +func (d *HalalCloud) GetCurrentOpDir(dir model.Obj, args []string, index int) string { + currentDir := dir.GetPath() + if len(currentDir) == 0 { + currentDir = "/" + } + opPath := currentDir + "/" + args[index] + if strings.HasPrefix(args[index], "/") { + opPath = args[index] + } + return opPath +} + +func (d *HalalCloud) GetCurrentDir(dir model.Obj) string { + currentDir := dir.GetPath() + if len(currentDir) == 0 { + currentDir = "/" + } + return currentDir +} + +type Common struct { +} + +func getRawFiles(addr *pubUserFile.SliceDownloadInfo) ([]byte, error) { + + if addr == nil { + return nil, errors.New("addr is nil") + } + + client := http.Client{ + Timeout: 60 * time.Second, // 60-second timeout per slice download + } + resp, err := client.Get(addr.DownloadAddress) + if err != nil { + + return nil, err + } + defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("bad status: %s, body: %s", resp.Status, body) + } + + if addr.Encrypt > 0 { + // slice bodies are XOR-obfuscated with a single-byte key + cd := uint8(addr.Encrypt) + for idx := 0; idx < len(body); idx++ { + body[idx] = body[idx] ^ cd + } + } + + if addr.StoreType != 10 { + + sourceCid, err := cid.Decode(addr.Identity) + if err != nil { + return nil, err + } + checkCid, err := sourceCid.Prefix().Sum(body) + if err != nil { + return nil, err + } + if !checkCid.Equals(sourceCid) { + return nil, fmt.Errorf("bad cid: %s, body: %s", checkCid.String(), body) + } + } + + return body, nil + +} + +type openObject struct { + ctx context.Context + mu sync.Mutex + d []*pubUserFile.SliceDownloadInfo + id int + skip int64 + chunk *[]byte + chunks *[]chunkSize + closed bool + sha string + shaTemp hash.Hash +} + +// get the next chunk +func (oo *openObject) getChunk(ctx context.Context) (err error) { + if oo.id >= len(*oo.chunks) { + return io.EOF + } + 
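+ // fetch the next slice, retrying transient download failures up to three times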
var chunk []byte + err = utils.Retry(3, time.Second, func() (err error) { + chunk, err = getRawFiles(oo.d[oo.id]) + return err + }) + if err != nil { + return err + } + oo.id++ + oo.chunk = &chunk + return nil +} + +// Read reads up to len(p) bytes into p. +func (oo *openObject) Read(p []byte) (n int, err error) { + oo.mu.Lock() + defer oo.mu.Unlock() + if oo.closed { + return 0, fmt.Errorf("read on closed file") + } + // Skip data at the start if requested + for oo.skip > 0 { + //size := 1024 * 1024 + _, size, err := oo.ChunkLocation(oo.id) + if err != nil { + return 0, err + } + if oo.skip < int64(size) { + break + } + oo.id++ + oo.skip -= int64(size) + } + if len(*oo.chunk) == 0 { + err = oo.getChunk(oo.ctx) + if err != nil { + return 0, err + } + if oo.skip > 0 { + *oo.chunk = (*oo.chunk)[oo.skip:] + oo.skip = 0 + } + } + n = copy(p, *oo.chunk) + // hash the bytes handed to the caller before consuming them from the buffer + oo.shaTemp.Write((*oo.chunk)[:n]) + *oo.chunk = (*oo.chunk)[n:] + + return n, nil +} + +// Close closes the file - the SHA1 mismatch, if any, is reported here +func (oo *openObject) Close() (err error) { + oo.mu.Lock() + defer oo.mu.Unlock() + if oo.closed { + return nil + } + // verify the SHA1 of everything that was read + if hex.EncodeToString(oo.shaTemp.Sum(nil)) != oo.sha { + return fmt.Errorf("failed to finish download: sha1 mismatch") + } + + oo.closed = true + return nil +} + +func GetMD5Hash(text string) string { + tHash := md5.Sum([]byte(text)) + return hex.EncodeToString(tHash[:]) +} + +// chunkSize describes the size and position of a chunk +type chunkSize struct { + position int64 + size int +} + +func getChunkSizes(sliceSize []*pubUserFile.SliceSize) (chunks []chunkSize) { + chunks = make([]chunkSize, 0) + for _, s := range sliceSize { + // special-case the last slice, whose EndIndex may be zero + if s.EndIndex == 0 { + s.EndIndex = s.StartIndex + } + for j := s.StartIndex; j <= s.EndIndex; j++ { + chunks = append(chunks, chunkSize{position: j, size: int(s.Size)}) + } + } + return chunks +} + +func (oo *openObject) ChunkLocation(id int) (position int64, size int, err error) { + if id < 0 || id >= len(*oo.chunks) { + return 0, 0, errors.New("invalid arguments") + } + + return (*oo.chunks)[id].position, (*oo.chunks)[id].size, nil +} diff --git a/drivers/homecloud/driver.go b/drivers/homecloud/driver.go new file mode 100644 index 0000000000000000000000000000000000000000..5d4741a60906ae7ff076d1f3d76a9678b5cdc265 --- /dev/null +++ b/drivers/homecloud/driver.go @@ -0,0 +1,359 @@ +package homecloud + +import ( + "bytes" + "context" + "crypto/md5" + "crypto/sha1" + "encoding/hex" + "fmt" + "io" + "mime/multipart" + "net/http" + "strconv" + "strings" + "time" + + "github.com/alist-org/alist/v3/drivers/base" + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/errs" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/pkg/cron" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/alist-org/alist/v3/pkg/utils/random" +) + +type HomeCloud struct { + model.Storage + Addition + AccessToken string + UserID string + cron *cron.Cron + Account string +} + +func (d *HomeCloud) Config() driver.Config { + return config +} + +func (d *HomeCloud) GetAddition() driver.Additional { + return &d.Addition +} + +func (d *HomeCloud) Init(ctx context.Context) error { + if d.RefreshToken == "" { + return fmt.Errorf("RefreshToken is empty") + } + + if len(d.Addition.RootFolderID) == 0 { + d.RootFolderID = "0" + } + + err := d.refreshToken() + if err != nil { + return err + } + + d.cron = cron.NewCron(time.Hour * 10) + d.cron.Do(func() { + err := d.refreshToken() + if err != nil { + return + } + }) + + return 
nil +} + +func (d *HomeCloud) Drop(ctx context.Context) error { + if d.cron != nil { + d.cron.Stop() + } + return nil +} + +func (d *HomeCloud) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) { + return d.familyGetFiles(dir.GetID()) +} + +func (d *HomeCloud) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) { + var url string + var err error + url, err = d.getLink(file.GetID()) + if err != nil { + return nil, err + } + + link := &model.Link{ + URL: url, + } + + // attach auth headers, otherwise newly uploaded files cannot be downloaded + header := make(http.Header) + header.Add("Cookie", "H_TOKEN="+d.AccessToken) + header.Add("User-Agent", "Mozilla/5.0 (iPhone; CPU iPhone OS 18_1_1 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/18.1.1 Mobile/15E148 Safari/604.1") + link.Header = header + + return link, nil +} + +func (d *HomeCloud) MakeDir(ctx context.Context, parentDir model.Obj, dirName string) error { + var err error + data := base.Json{ + "parentDirId": parentDir.GetID(), + "dirName": dirName, + "category": 0, + "userId": d.UserID, + "groupId": d.GroupID, + } + pathname := "/storage/addDirectory/v1" + _, err = d.post(pathname, data, nil) + return err +} + +func (d *HomeCloud) Move(ctx context.Context, srcObj, dstDir model.Obj) (model.Obj, error) { + data := base.Json{ + "fileIds": []string{srcObj.GetID()}, + "targetDirId": dstDir.GetID(), + "userId": d.UserID, + "groupId": d.GroupID, + } + pathname := "/storage/batchMoveFile/v1" + _, err := d.post(pathname, data, nil) + if err != nil { + return nil, err + } + return srcObj, nil +} + +func (d *HomeCloud) Rename(ctx context.Context, srcObj model.Obj, newName string) error { + var err error + data := base.Json{ + "fileId": srcObj.GetID(), + "fileName": newName, + "userId": d.UserID, + "groupId": d.GroupID, + } + pathname := "/storage/updateFileName/v1" + _, err = d.post(pathname, data, nil) + + return err +} + +func (d *HomeCloud) Copy(ctx context.Context, srcObj, dstDir model.Obj) error { + // server-side copy consumes extra quota, so the code is disabled for now + // data := base.Json{ + // "fileIds": []string{srcObj.GetID()}, + // "targetDirId": dstDir.GetID(), + // "targetGroupId": d.GroupID, + // "userId": d.UserID, + // "groupId": d.GroupID, + // } + // pathname := "/storage/batchCopyFile/v1" + // _, err := d.post(pathname, data, nil) + // return err + return errs.NotImplement +} + +func (d *HomeCloud) Remove(ctx context.Context, obj model.Obj) error { + data := base.Json{ + "fileIds": []string{obj.GetID()}, + "userId": d.UserID, + "groupId": d.GroupID, + } + pathname := "/storage/batchDeleteFile/v1" + if obj.IsDir() { + data = base.Json{ + "fileId": obj.GetID(), + "userId": d.UserID, + "groupId": d.GroupID, + } + pathname = "/storage/deleteDirectory/v1" + } + _, err := d.post(pathname, data, nil) + return err +} + +const ( + _ = iota // ignore first value by assigning to blank identifier + KB = 1 << (10 * iota) + MB + GB + TB +) + +func getPartSize(size int64) int64 { + // the drive caps the number of upload parts, so files over 30 GB use larger parts + if size/GB > 30 { + return 512 * MB + } + return 100 * MB +} + +func (d *HomeCloud) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error { + var err error + + h := md5.New() + // need to calculate md5 of the full content + tempFile, err := stream.CacheFullInTempFile() + if err != nil { + return err + } + defer func() { + _ = tempFile.Close() + }() + if _, err = io.Copy(h, tempFile); err != nil { + return err + } + _, err = tempFile.Seek(0, io.SeekStart) + if err != nil { + return err + } + etag := 
hex.EncodeToString(h.Sum(nil)) + + // return errs.NotImplement + data := base.Json{ + "userId": d.UserID, + "groupId": d.GroupID, + "dirId": dstDir.GetID(), + "fileName": stream.GetName(), + "fileMd5": etag, + "fileSize": stream.GetSize(), + "fileCategory": 99, + } + + pathname := "/storage/addFileUploadTask/v1" + var resp PersonalUploadResp + _, err = d.post(pathname, data, &resp) + if err != nil { + return err + } + + // Progress + p := driver.NewProgress(stream.GetSize(), up) + + var partSize = getPartSize(stream.GetSize()) + part := (stream.GetSize() + partSize - 1) / partSize + if part == 0 { + part = 1 + } + for i := int64(0); i < part; i++ { + if utils.IsCanceled(ctx) { + return ctx.Err() + } + + start := i * partSize + byteSize := stream.GetSize() - start + if byteSize > partSize { + byteSize = partSize + } + + limitReader := io.LimitReader(stream, byteSize) + // Update Progress + r := io.TeeReader(limitReader, p) + // Update Progress + //r := io.TeeReader(limitReader, p) + + body := &bytes.Buffer{} + writer := multipart.NewWriter(body) + filePart, err := writer.CreateFormFile("partFile", stream.GetName()) + if err != nil { + return err + } + _, err = io.Copy(filePart, r) + if err != nil { + return err + } + + isDone := false + + if i == (part - 1) { + isDone = true + } + + _ = writer.WriteField("uploadId", resp.Data.UploadId) + _ = writer.WriteField("isComplete", strconv.FormatBool(isDone)) + _ = writer.WriteField("rangeStart", strconv.Itoa(int(start))) + + err = writer.Close() + if err != nil { + return err + } + + req, err := http.NewRequest("POST", resp.Data.UploadUrl, body) + if err != nil { + return err + } + requestID := random.String(12) + pbody, err := utils.Json.Marshal(body) + + if err != nil { + return err + } + + timestamp := fmt.Sprintf("%.3f", float64(time.Now().UnixNano())/1e6) + h := sha1.New() + var sha1Hash string + + if pbody == nil { + h.Write([]byte("{}")) + sha1Hash = strings.ToUpper(hex.EncodeToString(h.Sum(nil))) + } else { + h.Write(pbody) + sha1Hash = strings.ToUpper(hex.EncodeToString(h.Sum(nil))) + } + + uppathname := "/upload/upload/uploadFilePart/v1" + encStr := fmt.Sprintf("%s;%s;%s;Bearer %s;%s", uppathname, sha1Hash, requestID, d.AccessToken, timestamp) + signature := strings.ToUpper(fmt.Sprintf("%x", md5.Sum([]byte(encStr)))) + + req = req.WithContext(ctx) + req.Header.Add("Accept", "*/*") + req.Header.Add("Accept-Language", "zh-CN,zh;q=0.9,en;q=0.8") + req.Header.Add("Authorization", "Bearer "+d.AccessToken) + req.Header.Add("Origin", "https://homecloud.komect.com") + req.Header.Add("Referer", "https://homecloud.komect.com/disk/main/familyspace") + req.Header.Add("Request-Id", requestID) + req.Header.Add("Signature", signature) + req.Header.Add("User-Agent", "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/127.0.0.0 Safari/537.36") + req.Header.Add("X-Requested-With", "XMLHttpRequest") + req.Header.Add("X-User-Agent", "Web|Chrome 127.0.0.0||OS X|homecloudWebDisk_1.1.1||yunpan 1.1.1|unknown") + req.Header.Add("sec-ch-ua", "\"Not)A;Brand\";v=\"99\", \"Google Chrome\";v=\"127\", \"Chromium\";v=\"127\"") + req.Header.Add("sec-ch-ua-mobile", "?0") + req.Header.Add("sec-ch-ua-platform", "\"macOS\"") + req.Header.Add("userId", d.UserID) + req.Header.Set("Content-Type", writer.FormDataContentType()) + + res, err := base.HttpClient.Do(req) + if err != nil { + return err + } + _ = res.Body.Close() + //log.Debugf("%+v", res) + if res.StatusCode != http.StatusOK { + return fmt.Errorf("unexpected status code: %d", 
res.StatusCode) + } + + } + + // url, err := d.getLink(resp.Data.FileId) + // if err != nil { + // return fmt.Errorf("can not get file donwnload url") + // } + + // _, err = base.RestyClient.R(). + // SetHeader("Cookie", "H_TOKEN="+d.AccessToken). + // SetHeader("Range", "bytes=0-100"). + // Get(url) + // if err != nil { + // return fmt.Errorf("can not active file") + // } + + return nil +} + +func (d *HomeCloud) Other(ctx context.Context, args model.OtherArgs) (interface{}, error) { + return nil, errs.NotImplement +} + +var _ driver.Driver = (*HomeCloud)(nil) diff --git a/drivers/homecloud/meta.go b/drivers/homecloud/meta.go new file mode 100644 index 0000000000000000000000000000000000000000..a117c4b9b351cbd6402e0d70750510ba6b41fe22 --- /dev/null +++ b/drivers/homecloud/meta.go @@ -0,0 +1,25 @@ +package homecloud + +import ( + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/op" +) + +// https://homecloud.komect.com/ +type Addition struct { + //Account string `json:"account" required:"true"` + RefreshToken string `json:"refresh_token" required:"true"` + driver.RootID + GroupID string `json:"groupId" required:"true"` +} + +var config = driver.Config{ + Name: "homecloud", + LocalSort: true, +} + +func init() { + op.RegisterDriver(func() driver.Driver { + return &HomeCloud{} + }) +} diff --git a/drivers/homecloud/types.go b/drivers/homecloud/types.go new file mode 100644 index 0000000000000000000000000000000000000000..1efe36da46b2d1391c798b61a11fbe6cb94be021 --- /dev/null +++ b/drivers/homecloud/types.go @@ -0,0 +1,265 @@ +package homecloud + +const ( + MetaPersonal string = "personal" + MetaFamily string = "family" + MetaPersonalNew string = "personal_new" +) + +type BaseResp struct { + Success bool `json:"success"` + Code string `json:"code"` + Message string `json:"message"` +} + +type FrontResp struct { + Ret int `json:"ret"` + Reason string `json:"reason"` +} + +type Catalog struct { + CatalogID string `json:"catalogID"` + CatalogName string `json:"catalogName"` + //CatalogType int `json:"catalogType"` + CreateTime string `json:"createTime"` + UpdateTime string `json:"updateTime"` + //IsShared bool `json:"isShared"` + //CatalogLevel int `json:"catalogLevel"` + //ShareDoneeCount int `json:"shareDoneeCount"` + //OpenType int `json:"openType"` + //ParentCatalogID string `json:"parentCatalogId"` + //DirEtag int `json:"dirEtag"` + //Tombstoned int `json:"tombstoned"` + //ProxyID interface{} `json:"proxyID"` + //Moved int `json:"moved"` + //IsFixedDir int `json:"isFixedDir"` + //IsSynced interface{} `json:"isSynced"` + //Owner string `json:"owner"` + //Modifier interface{} `json:"modifier"` + //Path string `json:"path"` + //ShareType int `json:"shareType"` + //SoftLink interface{} `json:"softLink"` + //ExtProp1 interface{} `json:"extProp1"` + //ExtProp2 interface{} `json:"extProp2"` + //ExtProp3 interface{} `json:"extProp3"` + //ExtProp4 interface{} `json:"extProp4"` + //ExtProp5 interface{} `json:"extProp5"` + //ETagOprType int `json:"ETagOprType"` +} + +type Content struct { + ContentID string `json:"contentID"` + ContentName string `json:"contentName"` + //ContentSuffix string `json:"contentSuffix"` + ContentSize int64 `json:"contentSize"` + //ContentDesc string `json:"contentDesc"` + //ContentType int `json:"contentType"` + //ContentOrigin int `json:"contentOrigin"` + UpdateTime string `json:"updateTime"` + //CommentCount int `json:"commentCount"` + ThumbnailURL string `json:"thumbnailURL"` + //BigthumbnailURL string `json:"bigthumbnailURL"` + 
//PresentURL string `json:"presentURL"` + //PresentLURL string `json:"presentLURL"` + //PresentHURL string `json:"presentHURL"` + //ContentTAGList interface{} `json:"contentTAGList"` + //ShareDoneeCount int `json:"shareDoneeCount"` + //Safestate int `json:"safestate"` + //Transferstate int `json:"transferstate"` + //IsFocusContent int `json:"isFocusContent"` + //UpdateShareTime interface{} `json:"updateShareTime"` + //UploadTime string `json:"uploadTime"` + //OpenType int `json:"openType"` + //AuditResult int `json:"auditResult"` + //ParentCatalogID string `json:"parentCatalogId"` + //Channel string `json:"channel"` + //GeoLocFlag string `json:"geoLocFlag"` + Digest string `json:"digest"` + //Version string `json:"version"` + //FileEtag string `json:"fileEtag"` + //FileVersion string `json:"fileVersion"` + //Tombstoned int `json:"tombstoned"` + //ProxyID string `json:"proxyID"` + //Moved int `json:"moved"` + //MidthumbnailURL string `json:"midthumbnailURL"` + //Owner string `json:"owner"` + //Modifier string `json:"modifier"` + //ShareType int `json:"shareType"` + //ExtInfo struct { + // Uploader string `json:"uploader"` + // Address string `json:"address"` + //} `json:"extInfo"` + //Exif struct { + // CreateTime string `json:"createTime"` + // Longitude interface{} `json:"longitude"` + // Latitude interface{} `json:"latitude"` + // LocalSaveTime interface{} `json:"localSaveTime"` + //} `json:"exif"` + //CollectionFlag interface{} `json:"collectionFlag"` + //TreeInfo interface{} `json:"treeInfo"` + //IsShared bool `json:"isShared"` + //ETagOprType int `json:"ETagOprType"` +} + +type GetDiskResp struct { + BaseResp + Data struct { + Result struct { + ResultCode string `json:"resultCode"` + ResultDesc interface{} `json:"resultDesc"` + } `json:"result"` + GetDiskResult struct { + ParentCatalogID string `json:"parentCatalogID"` + NodeCount int `json:"nodeCount"` + CatalogList []Catalog `json:"catalogList"` + ContentList []Content `json:"contentList"` + IsCompleted int `json:"isCompleted"` + } `json:"getDiskResult"` + } `json:"data"` +} + +type UploadResp struct { + BaseResp + Data struct { + Result struct { + ResultCode string `json:"resultCode"` + ResultDesc interface{} `json:"resultDesc"` + } `json:"result"` + UploadResult struct { + UploadTaskID string `json:"uploadTaskID"` + RedirectionURL string `json:"redirectionUrl"` + NewContentIDList []struct { + ContentID string `json:"contentID"` + ContentName string `json:"contentName"` + IsNeedUpload string `json:"isNeedUpload"` + FileEtag int64 `json:"fileEtag"` + FileVersion int64 `json:"fileVersion"` + OverridenFlag int `json:"overridenFlag"` + } `json:"newContentIDList"` + CatalogIDList interface{} `json:"catalogIDList"` + IsSlice interface{} `json:"isSlice"` + } `json:"uploadResult"` + } `json:"data"` +} + +type CloudContent struct { + ContentID string `json:"contentID"` + //Modifier string `json:"modifier"` + //Nickname string `json:"nickname"` + //CloudNickName string `json:"cloudNickName"` + ContentName string `json:"contentName"` + //ContentType int `json:"contentType"` + //ContentSuffix string `json:"contentSuffix"` + ContentSize int64 `json:"contentSize"` + //ContentDesc string `json:"contentDesc"` + CreateTime string `json:"createTime"` + //Shottime interface{} `json:"shottime"` + LastUpdateTime string `json:"lastUpdateTime"` + ThumbnailURL string `json:"thumbnailURL"` + //MidthumbnailURL string `json:"midthumbnailURL"` + //BigthumbnailURL string `json:"bigthumbnailURL"` + //PresentURL string `json:"presentURL"` + //PresentLURL string 
`json:"presentLURL"` + //PresentHURL string `json:"presentHURL"` + //ParentCatalogID string `json:"parentCatalogID"` + //Uploader string `json:"uploader"` + //UploaderNickName string `json:"uploaderNickName"` + //TreeInfo interface{} `json:"treeInfo"` + //UpdateTime interface{} `json:"updateTime"` + //ExtInfo struct { + // Uploader string `json:"uploader"` + //} `json:"extInfo"` + //EtagOprType interface{} `json:"etagOprType"` +} + +type CloudCatalog struct { + CatalogID string `json:"catalogID"` + CatalogName string `json:"catalogName"` + //CloudID string `json:"cloudID"` + CreateTime string `json:"createTime"` + LastUpdateTime string `json:"lastUpdateTime"` + //Creator string `json:"creator"` + //CreatorNickname string `json:"creatorNickname"` +} + +type QueryContentListResp struct { + FrontResp + Data struct { + Ret int `json:"ret"` + Reason string `json:"reason"` + Total string `json:"total"` + FileInfos []FileInfo `json:"fileInfos"` + } `json:"data"` +} + +type FileInfo struct { + ID string `json:"id"` + Name string `json:"name"` + Type int64 `json:"type"` + Category int64 `json:"category"` + Md5 string `json:"md5"` + Size string `json:"size"` + ParentId string `json:"parentId"` + GroupID string `json:"groupId"` + UserID string `json:"userId"` + CreateTime string `json:"createTime"` + UpdateTime string `json:"updateTime"` + Ctag string `json:"ctag"` + ParentCategory int64 `json:"parentCategory"` +} + +type PersonalThumbnail struct { + Style string `json:"style"` + Url string `json:"url"` +} + +type PersonalFileItem struct { + FileId string `json:"fileId"` + Name string `json:"name"` + Size int64 `json:"size"` + Type string `json:"type"` + CreatedAt string `json:"createdAt"` + UpdatedAt string `json:"updatedAt"` + Thumbnails []PersonalThumbnail `json:"thumbnailUrls"` +} + +type PersonalListResp struct { + BaseResp + Data struct { + Items []PersonalFileItem `json:"items"` + NextPageCursor string `json:"nextPageCursor"` + } +} + +type PersonalPartInfo struct { + PartNumber int `json:"partNumber"` + UploadUrl string `json:"uploadUrl"` +} + +type PersonalUploadResp struct { + FrontResp + Data struct { + Ret int `json:"ret"` + Reason string `json:"reason"` + UploadId string `json:"uploadId"` + UploadState int `json:"uploadState"` + ExpireTime string `json:"expireTime"` + FileId string `json:"fileId"` + UploadUrl string `json:"uploadUrl"` + } +} + +type RefreshTokenResp struct { + FrontResp + Data struct { + Ret int `json:"ret"` + ExpiresIn int `json:"expiresIn"` + License string `json:"license"` + Scope string `json:"scope"` + UserType int `json:"userType"` + AccessToken string `json:"accessToken"` + TokenType string `json:"tokenType"` + UserID string `json:"userId"` + RefreshToken string `json:"refreshToken"` + } +} diff --git a/drivers/homecloud/util.go b/drivers/homecloud/util.go new file mode 100644 index 0000000000000000000000000000000000000000..43a3d2bcc2baef2074a47b3a481ee930cab8bc9e --- /dev/null +++ b/drivers/homecloud/util.go @@ -0,0 +1,512 @@ +package homecloud + +import ( + "crypto/md5" + "crypto/sha1" + "encoding/base64" + "encoding/hex" + "errors" + "fmt" + "net/http" + "net/url" + "sort" + "strconv" + "strings" + "time" + + "github.com/alist-org/alist/v3/drivers/base" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/internal/op" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/alist-org/alist/v3/pkg/utils/random" + "github.com/go-resty/resty/v2" + jsoniter "github.com/json-iterator/go" + log "github.com/sirupsen/logrus" +) + +// do 
others that not defined in Driver interface +func (d *HomeCloud) isFamily() bool { + return false +} + +func encodeURIComponent(str string) string { + r := url.QueryEscape(str) + r = strings.Replace(r, "+", "%20", -1) + r = strings.Replace(r, "%21", "!", -1) + r = strings.Replace(r, "%27", "'", -1) + r = strings.Replace(r, "%28", "(", -1) + r = strings.Replace(r, "%29", ")", -1) + r = strings.Replace(r, "%2A", "*", -1) + return r +} + +func calSign(body, ts, randStr string) string { + body = encodeURIComponent(body) + strs := strings.Split(body, "") + sort.Strings(strs) + body = strings.Join(strs, "") + body = base64.StdEncoding.EncodeToString([]byte(body)) + res := utils.GetMD5EncodeStr(body) + utils.GetMD5EncodeStr(ts+":"+randStr) + res = strings.ToUpper(utils.GetMD5EncodeStr(res)) + return res +} + +func getTime(t string) time.Time { + stamp, _ := time.ParseInLocation("20060102150405", t, utils.CNLoc) + return stamp +} + +func (d *HomeCloud) refreshToken() error { + pathname := "/auth/refreshToken/v2" + url := "https://homecloud.komect.com/front" + pathname + + data := base.Json{ + "refresh_token": d.RefreshToken, + "scope": "sdk", + } + + requestID := random.String(12) + body, err := utils.Json.Marshal(data) + + if err != nil { + return err + } + + timestamp := fmt.Sprintf("%.3f", float64(time.Now().UnixNano())/1e6) + h := sha1.New() + var sha1Hash string + + if body == nil { + h.Write([]byte("{}")) + sha1Hash = strings.ToUpper(hex.EncodeToString(h.Sum(nil))) + } else { + h.Write(body) + sha1Hash = strings.ToUpper(hex.EncodeToString(h.Sum(nil))) + } + + encStr := fmt.Sprintf("%s;%s;%s;Basic VTdUOU1xSHpVbklqeWdETzppQzZCU25QaExyODZGZmJX;%s", pathname, sha1Hash, requestID, timestamp) + signature := strings.ToUpper(fmt.Sprintf("%x", md5.Sum([]byte(encStr)))) + + var resp RefreshTokenResp + var e FrontResp + req := base.RestyClient.R() + req.SetHeaders(map[string]string{ + "Accept": "application/json, text/plain, */*", + "Authorization": "Basic VTdUOU1xSHpVbklqeWdETzppQzZCU25QaExyODZGZmJX", + "Content-Type": "application/json", + "X-User-Agent": "Web|Chrome 127.0.0.0||OS X|homecloudWebDisk_1.1.1||yunpan 1.1.1|unknown", + "Timestamp": timestamp, + "Signature": signature, + "Request-Id": requestID, + "userId": "", + }) + req.SetBody(data) + req.SetResult(&resp) + req.SetError(&e) + _, err = req.Post(url) + //fmt.Println(string(res.Body())) + + if err != nil { + return err + } + // if e.Ret != 200 { + // return fmt.Errorf("failed to refresh token: %s", e.Reason) + // } + if resp.Data.RefreshToken == "" { + return errors.New("failed to refresh token: refresh token is empty") + } + d.RefreshToken, d.AccessToken, d.UserID = resp.Data.RefreshToken, resp.Data.AccessToken, resp.Data.UserID + op.MustSaveDriverStorage(d) + return nil +} + +func (d *HomeCloud) request(pathname string, method string, callback base.ReqCallback, resp interface{}) ([]byte, error) { + url := "https://homecloud.komect.com/front" + pathname + req := base.RestyClient.R() + requestID := random.String(12) + //ts := time.Now().Format("2006-01-02 15:04:05") + if callback != nil { + callback(req) + } + body, err := utils.Json.Marshal(req.Body) + + if err != nil { + return nil, err + } + + timestamp := fmt.Sprintf("%.3f", float64(time.Now().UnixNano())/1e6) + h := sha1.New() + var sha1Hash string + + if body == nil { + h.Write([]byte("{}")) + sha1Hash = strings.ToUpper(hex.EncodeToString(h.Sum(nil))) + } else { + h.Write(body) + sha1Hash = strings.ToUpper(hex.EncodeToString(h.Sum(nil))) + } + + encStr := fmt.Sprintf("%s;%s;%s;Bearer 
%s;%s", pathname, sha1Hash, requestID, d.AccessToken, timestamp) + signature := strings.ToUpper(fmt.Sprintf("%x", md5.Sum([]byte(encStr)))) + + req.SetHeaders(map[string]string{ + "Accept": "application/json, text/plain, */*", + "Authorization": "Bearer " + d.AccessToken, + "Content-Type": "application/json", + "X-User-Agent": "Web|Chrome 127.0.0.0||OS X|homecloudWebDisk_1.1.1||yunpan 1.1.1|unknown", + "Timestamp": timestamp, + "Signature": signature, + "Request-Id": requestID, + "userId": d.UserID, + }) + + var e FrontResp + req.SetResult(&e) + res, err := req.Execute(method, url) + if err != nil { + return nil, err + } + + //log.Debugln(res.String()) + if e.Ret != 200 { + return nil, errors.New(e.Reason) + } + if resp != nil { + err = utils.Json.Unmarshal(res.Body(), resp) + if err != nil { + return nil, err + } + } + //fmt.Println(string(res.Body())) + return res.Body(), nil +} + +func (d *HomeCloud) post(pathname string, data interface{}, resp interface{}) ([]byte, error) { + return d.request(pathname, http.MethodPost, func(req *resty.Request) { + req.SetBody(data) + }, resp) +} + +func (d *HomeCloud) getFiles(catalogID string) ([]model.Obj, error) { + start := 0 + limit := 100 + files := make([]model.Obj, 0) + for { + data := base.Json{ + "catalogID": catalogID, + "sortDirection": 1, + "startNumber": start + 1, + "endNumber": start + limit, + "filterType": 0, + "catalogSortType": 0, + "contentSortType": 0, + "commonAccountInfo": base.Json{ + "account": d.Account, + "accountType": 1, + }, + } + var resp GetDiskResp + _, err := d.post("/orchestration/personalCloud/catalog/v1.0/getDisk", data, &resp) + if err != nil { + return nil, err + } + for _, catalog := range resp.Data.GetDiskResult.CatalogList { + f := model.Object{ + ID: catalog.CatalogID, + Name: catalog.CatalogName, + Size: 0, + Modified: getTime(catalog.UpdateTime), + Ctime: getTime(catalog.CreateTime), + IsFolder: true, + } + files = append(files, &f) + } + for _, content := range resp.Data.GetDiskResult.ContentList { + f := model.ObjThumb{ + Object: model.Object{ + ID: content.ContentID, + Name: content.ContentName, + Size: content.ContentSize, + Modified: getTime(content.UpdateTime), + HashInfo: utils.NewHashInfo(utils.MD5, content.Digest), + }, + Thumbnail: model.Thumbnail{Thumbnail: content.ThumbnailURL}, + //Thumbnail: content.BigthumbnailURL, + } + files = append(files, &f) + } + if start+limit >= resp.Data.GetDiskResult.NodeCount { + break + } + start += limit + } + return files, nil +} + +func (d *HomeCloud) newJson(data map[string]interface{}) base.Json { + common := map[string]interface{}{} + return utils.MergeMap(data, common) +} + +func (d *HomeCloud) familyGetFiles(catalogID string) ([]model.Obj, error) { + + // if strings.Contains(catalogID, "/") { + // catalogID = "0" + // } + + pageNum := 1 + files := make([]model.Obj, 0) + for { + data := base.Json{ + "pageInfo": base.Json{ + "pageNum": pageNum, + "pageSize": 100, + }, + "sortInfo": base.Json{ + "sortField": 1, + "sortOrder": 2, + }, + "userId": d.UserID, + "groupId": d.GroupID, + "fileId": catalogID, + } + + //https://homecloud.komect.com/front/storage/getFileInfoList/v1 + var resp QueryContentListResp + _, err := d.post("/storage/getFileInfoList/v1", data, &resp) + if err != nil { + return nil, err + } + for _, content := range resp.Data.FileInfos { + filesize, err := strconv.ParseInt(content.Size, 10, 64) + + if err != nil { + return nil, err + } + + isfolder := false + + if content.Type == 1 { + isfolder = true + } + + ctimestamp, err := strconv.ParseInt(content.CreateTime, 10, 64) + if err != nil { + fmt.Println("Error parsing timestamp:", err) + return nil, err + } + + mtimestamp, err := strconv.ParseInt(content.UpdateTime, 10, 64) + if err != nil { + fmt.Println("Error parsing timestamp:", err) + return nil, err + } + + // convert epoch milliseconds to seconds and nanoseconds + cseconds := ctimestamp / 1000 + cnanoseconds := (ctimestamp % 1000) * 1000000 + mseconds := mtimestamp / 1000 + mnanoseconds := (mtimestamp % 1000) * 1000000 + + // build the time.Time values + ct := time.Unix(cseconds, cnanoseconds) + mt := time.Unix(mseconds, mnanoseconds) + + f := model.ObjThumb{ + Object: model.Object{ + ID: content.ID, + Name: content.Name, + Size: filesize, + IsFolder: isfolder, + Modified: mt, + Ctime: ct, + }, + } + files = append(files, &f) + } + + totalCount, err := strconv.Atoi(resp.Data.Total) + + if err != nil { + return nil, err + } + + if 100*pageNum > totalCount { + break + } + pageNum++ + } + return files, nil +} + +func (d *HomeCloud) getLink(contentId string) (string, error) { + data := base.Json{ + "userId": d.UserID, + "groupId": d.GroupID, + "fileId": contentId, + } + + res, err := d.post("/storage/getFileDownloadUrl/v1", + data, nil) + if err != nil { + return "", err + } + downloadURL := "https://cdn.homecloud.komect.com/gateway" + jsoniter.Get(res, "data", "downloadUrl").ToString() + return downloadURL, nil +} + +func unicode(str string) string { + textQuoted := strconv.QuoteToASCII(str) + textUnquoted := textQuoted[1 : len(textQuoted)-1] + return textUnquoted +} + +func (d *HomeCloud) personalRequest(pathname string, method string, callback base.ReqCallback, resp interface{}) ([]byte, error) { + url := "https://personal-kd-njs.yun.139.com" + pathname + req := base.RestyClient.R() + randStr := random.String(16) + ts := time.Now().Format("2006-01-02 15:04:05") + if callback != nil { + callback(req) + } + body, err := utils.Json.Marshal(req.Body) + if err != nil { + return nil, err + } + sign := calSign(string(body), ts, randStr) + svcType := "1" + if d.isFamily() { + svcType = "2" + } + req.SetHeaders(map[string]string{ + "Accept": "application/json, text/plain, */*", + "Authorization": "Basic " + d.AccessToken, + "Caller": "web", + "Cms-Device": "default", + "Mcloud-Channel": "1000101", + "Mcloud-Client": "10701", + "Mcloud-Route": "001", + "Mcloud-Sign": fmt.Sprintf("%s,%s,%s", ts, randStr, sign), + "Mcloud-Version": "7.13.0", + "Origin": "https://yun.139.com", + "Referer": "https://yun.139.com/w/", + "x-DeviceInfo": "||9|7.13.0|chrome|120.0.0.0|||windows 10||zh-CN|||", + "x-huawei-channelSrc": "10000034", + "x-inner-ntwk": "2", + "x-m4c-caller": "PC", + "x-m4c-src": "10002", + "x-SvcType": svcType, + "X-Yun-Api-Version": "v1", + "X-Yun-App-Channel": "10000034", + "X-Yun-Channel-Source": "10000034", + "X-Yun-Client-Info": "||9|7.13.0|chrome|120.0.0.0|||windows 10||zh-CN|||dW5kZWZpbmVk||", + "X-Yun-Module-Type": "100", + "X-Yun-Svc-Type": "1", + }) + + var e BaseResp + req.SetResult(&e) + res, err := req.Execute(method, url) + if err != nil { + return nil, err + } + log.Debugln(res.String()) + if !e.Success { + return nil, errors.New(e.Message) + } + if resp != nil { + err = utils.Json.Unmarshal(res.Body(), resp) + if err != nil { + return nil, err + } + } + return res.Body(), nil +} +func (d *HomeCloud) personalPost(pathname string, data interface{}, resp interface{}) ([]byte, error) { + return d.personalRequest(pathname, http.MethodPost, func(req *resty.Request) { + req.SetBody(data) + }, resp) +} + +func getPersonalTime(t string) time.Time { + stamp, err := 
time.ParseInLocation("2006-01-02T15:04:05.999-07:00", t, utils.CNLoc) + if err != nil { + panic(err) + } + return stamp +} + +func (d *HomeCloud) personalGetFiles(fileId string) ([]model.Obj, error) { + files := make([]model.Obj, 0) + nextPageCursor := "" + for { + data := base.Json{ + "imageThumbnailStyleList": []string{"Small", "Large"}, + "orderBy": "updated_at", + "orderDirection": "DESC", + "pageInfo": base.Json{ + "pageCursor": nextPageCursor, + "pageSize": 100, + }, + "parentFileId": fileId, + } + var resp PersonalListResp + _, err := d.personalPost("/hcy/file/list", data, &resp) + if err != nil { + return nil, err + } + nextPageCursor = resp.Data.NextPageCursor + for _, item := range resp.Data.Items { + var isFolder = (item.Type == "folder") + var f model.Obj + if isFolder { + f = &model.Object{ + ID: item.FileId, + Name: item.Name, + Size: 0, + Modified: getPersonalTime(item.UpdatedAt), + Ctime: getPersonalTime(item.CreatedAt), + IsFolder: isFolder, + } + } else { + var Thumbnails = item.Thumbnails + var ThumbnailUrl string + if len(Thumbnails) > 0 { + ThumbnailUrl = Thumbnails[len(Thumbnails)-1].Url + } + f = &model.ObjThumb{ + Object: model.Object{ + ID: item.FileId, + Name: item.Name, + Size: item.Size, + Modified: getPersonalTime(item.UpdatedAt), + Ctime: getPersonalTime(item.CreatedAt), + IsFolder: isFolder, + }, + Thumbnail: model.Thumbnail{Thumbnail: ThumbnailUrl}, + } + } + files = append(files, f) + } + if len(nextPageCursor) == 0 { + break + } + } + return files, nil +} + +func (d *HomeCloud) personalGetLink(fileId string) (string, error) { + data := base.Json{ + "fileId": fileId, + } + res, err := d.personalPost("/hcy/file/getDownloadUrl", + data, nil) + if err != nil { + return "", err + } + var cdnUrl = jsoniter.Get(res, "data", "cdnUrl").ToString() + if cdnUrl != "" { + return cdnUrl, nil + } else { + return jsoniter.Get(res, "data", "url").ToString(), nil + } +} diff --git a/drivers/ilanzou/driver.go b/drivers/ilanzou/driver.go new file mode 100644 index 0000000000000000000000000000000000000000..90ef7c1a9109e5e948cc78979ad9bcf6ea35b12a --- /dev/null +++ b/drivers/ilanzou/driver.go @@ -0,0 +1,388 @@ +package template + +import ( + "context" + "crypto/md5" + "encoding/base64" + "encoding/hex" + "fmt" + "io" + "net/http" + "net/url" + "strconv" + "strings" + "time" + + "github.com/alist-org/alist/v3/drivers/base" + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/errs" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/foxxorcat/mopan-sdk-go" + "github.com/go-resty/resty/v2" + log "github.com/sirupsen/logrus" +) + +type ILanZou struct { + model.Storage + Addition + + userID string + account string + upClient *resty.Client + conf Conf + config driver.Config +} + +func (d *ILanZou) Config() driver.Config { + return d.config +} + +func (d *ILanZou) GetAddition() driver.Additional { + return &d.Addition +} + +func (d *ILanZou) Init(ctx context.Context) error { + d.upClient = base.NewRestyClient().SetTimeout(time.Minute * 10) + if d.UUID == "" { + res, err := d.unproved("/getUuid", http.MethodGet, nil) + if err != nil { + return err + } + d.UUID = utils.Json.Get(res, "uuid").ToString() + } + res, err := d.proved("/user/account/map", http.MethodGet, nil) + if err != nil { + return err + } + d.userID = utils.Json.Get(res, "map", "userId").ToString() + d.account = utils.Json.Get(res, "map", "account").ToString() + log.Debugf("[ilanzou] init response: %s", res) + return nil +} + 
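+// Every endpoint below is called through request() in util.go, which appends the
+// device UUID, an AES-encrypted millisecond timestamp and, for "proved" endpoints,
+// the appToken obtained via login() to the query string.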
+func (d *ILanZou) Drop(ctx context.Context) error { + return nil +} + +func (d *ILanZou) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) { + offset := 1 + var res []ListItem + for { + var resp ListResp + _, err := d.proved("/record/file/list", http.MethodGet, func(req *resty.Request) { + params := []string{ + "offset=" + strconv.Itoa(offset), + "limit=60", + "folderId=" + dir.GetID(), + "type=0", + } + queryString := strings.Join(params, "&") + req.SetQueryString(queryString).SetResult(&resp) + }) + if err != nil { + return nil, err + } + res = append(res, resp.List...) + if resp.Offset < resp.TotalPage { + offset++ + } else { + break + } + } + return utils.SliceConvert(res, func(f ListItem) (model.Obj, error) { + updTime, err := time.ParseInLocation("2006-01-02 15:04:05", f.UpdTime, time.Local) + if err != nil { + return nil, err + } + obj := model.Object{ + ID: strconv.FormatInt(f.FileId, 10), + //Path: "", + Name: f.FileName, + Size: f.FileSize * 1024, + Modified: updTime, + Ctime: updTime, + IsFolder: false, + //HashInfo: utils.HashInfo{}, + } + if f.FileType == 2 { + obj.IsFolder = true + obj.Size = 0 + obj.ID = strconv.FormatInt(f.FolderId, 10) + obj.Name = f.FolderName + } + return &obj, nil + }) +} + +func (d *ILanZou) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) { + u, err := url.Parse(d.conf.base + "/" + d.conf.unproved + "/file/redirect") + if err != nil { + return nil, err + } + ts, tsStr, err := getTimestamp(d.conf.secret) + if err != nil { + return nil, err + } + + params := []string{ + "uuid=" + url.QueryEscape(d.UUID), + "devType=6", + "devCode=" + url.QueryEscape(d.UUID), + "devModel=chrome", + "devVersion=" + url.QueryEscape(d.conf.devVersion), + "appVersion=", + "timestamp=" + tsStr, + "appToken=" + url.QueryEscape(d.Token), + "enable=0", + } + + downloadId, err := mopan.AesEncrypt([]byte(fmt.Sprintf("%s|%s", file.GetID(), d.userID)), d.conf.secret) + if err != nil { + return nil, err + } + params = append(params, "downloadId="+url.QueryEscape(hex.EncodeToString(downloadId))) + + auth, err := mopan.AesEncrypt([]byte(fmt.Sprintf("%s|%d", file.GetID(), ts)), d.conf.secret) + if err != nil { + return nil, err + } + params = append(params, "auth="+url.QueryEscape(hex.EncodeToString(auth))) + + u.RawQuery = strings.Join(params, "&") + realURL := u.String() + // get the url after redirect + res, err := base.NoRedirectClient.R().SetHeaders(map[string]string{ + //"Origin": d.conf.site, + "Referer": d.conf.site + "/", + "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/125.0.0.0 Safari/537.36 Edg/125.0.0.0", + }).Get(realURL) + if err != nil { + return nil, err + } + if res.StatusCode() == 302 { + realURL = res.Header().Get("location") + } else { + return nil, fmt.Errorf("redirect failed, status: %d, msg: %s", res.StatusCode(), utils.Json.Get(res.Body(), "msg").ToString()) + } + link := model.Link{URL: realURL} + return &link, nil +} + +func (d *ILanZou) MakeDir(ctx context.Context, parentDir model.Obj, dirName string) (model.Obj, error) { + res, err := d.proved("/file/folder/save", http.MethodPost, func(req *resty.Request) { + req.SetBody(base.Json{ + "folderDesc": "", + "folderId": parentDir.GetID(), + "folderName": dirName, + }) + }) + if err != nil { + return nil, err + } + return &model.Object{ + ID: utils.Json.Get(res, "list", 0, "id").ToString(), + //Path: "", + Name: dirName, + Size: 0, + Modified: time.Now(), + Ctime: time.Now(), + IsFolder: true, + //HashInfo: 
utils.HashInfo{}, + }, nil +} + +func (d *ILanZou) Move(ctx context.Context, srcObj, dstDir model.Obj) (model.Obj, error) { + var fileIds, folderIds []string + if srcObj.IsDir() { + folderIds = []string{srcObj.GetID()} + } else { + fileIds = []string{srcObj.GetID()} + } + _, err := d.proved("/file/folder/move", http.MethodPost, func(req *resty.Request) { + req.SetBody(base.Json{ + "folderIds": strings.Join(folderIds, ","), + "fileIds": strings.Join(fileIds, ","), + "targetId": dstDir.GetID(), + }) + }) + if err != nil { + return nil, err + } + return srcObj, nil +} + +func (d *ILanZou) Rename(ctx context.Context, srcObj model.Obj, newName string) (model.Obj, error) { + var err error + if srcObj.IsDir() { + _, err = d.proved("/file/folder/edit", http.MethodPost, func(req *resty.Request) { + req.SetBody(base.Json{ + "folderDesc": "", + "folderId": srcObj.GetID(), + "folderName": newName, + }) + }) + } else { + _, err = d.proved("/file/edit", http.MethodPost, func(req *resty.Request) { + req.SetBody(base.Json{ + "fileDesc": "", + "fileId": srcObj.GetID(), + "fileName": newName, + }) + }) + } + if err != nil { + return nil, err + } + return &model.Object{ + ID: srcObj.GetID(), + //Path: "", + Name: newName, + Size: srcObj.GetSize(), + Modified: time.Now(), + Ctime: srcObj.CreateTime(), + IsFolder: srcObj.IsDir(), + }, nil +} + +func (d *ILanZou) Copy(ctx context.Context, srcObj, dstDir model.Obj) (model.Obj, error) { + // TODO copy obj, optional + return nil, errs.NotImplement +} + +func (d *ILanZou) Remove(ctx context.Context, obj model.Obj) error { + var fileIds, folderIds []string + if obj.IsDir() { + folderIds = []string{obj.GetID()} + } else { + fileIds = []string{obj.GetID()} + } + _, err := d.proved("/file/delete", http.MethodPost, func(req *resty.Request) { + req.SetBody(base.Json{ + "folderIds": strings.Join(folderIds, ","), + "fileIds": strings.Join(fileIds, ","), + "status": 0, + }) + }) + return err +} + +const DefaultPartSize = 1024 * 1024 * 8 + +func (d *ILanZou) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) (model.Obj, error) { + h := md5.New() + // need to calculate md5 of the full content + tempFile, err := stream.CacheFullInTempFile() + if err != nil { + return nil, err + } + defer func() { + _ = tempFile.Close() + }() + if _, err = utils.CopyWithBuffer(h, tempFile); err != nil { + return nil, err + } + _, err = tempFile.Seek(0, io.SeekStart) + if err != nil { + return nil, err + } + etag := hex.EncodeToString(h.Sum(nil)) + // get upToken + res, err := d.proved("/7n/getUpToken", http.MethodPost, func(req *resty.Request) { + req.SetBody(base.Json{ + "fileId": "", + "fileName": stream.GetName(), + "fileSize": stream.GetSize()/1024 + 1, + "folderId": dstDir.GetID(), + "md5": etag, + "type": 1, + }) + }) + if err != nil { + return nil, err + } + upToken := utils.Json.Get(res, "upToken").ToString() + now := time.Now() + key := fmt.Sprintf("disk/%d/%d/%d/%s/%016d", now.Year(), now.Month(), now.Day(), d.account, now.UnixMilli()) + var token string + if stream.GetSize() <= DefaultPartSize { + res, err := d.upClient.R().SetMultipartFormData(map[string]string{ + "token": upToken, + "key": key, + "fname": stream.GetName(), + }).SetMultipartField("file", stream.GetName(), stream.GetMimetype(), tempFile). 
+ Post("https://upload.qiniup.com/") + if err != nil { + return nil, err + } + token = utils.Json.Get(res.Body(), "token").ToString() + } else { + keyBase64 := base64.URLEncoding.EncodeToString([]byte(key)) + res, err := d.upClient.R().SetHeader("Authorization", "UpToken "+upToken).Post(fmt.Sprintf("https://upload.qiniup.com/buckets/%s/objects/%s/uploads", d.conf.bucket, keyBase64)) + if err != nil { + return nil, err + } + uploadId := utils.Json.Get(res.Body(), "uploadId").ToString() + parts := make([]Part, 0) + partNum := (stream.GetSize() + DefaultPartSize - 1) / DefaultPartSize + for i := 1; i <= int(partNum); i++ { + u := fmt.Sprintf("https://upload.qiniup.com/buckets/%s/objects/%s/uploads/%s/%d", d.conf.bucket, keyBase64, uploadId, i) + res, err = d.upClient.R().SetHeader("Authorization", "UpToken "+upToken).SetBody(io.LimitReader(tempFile, DefaultPartSize)).Put(u) + if err != nil { + return nil, err + } + etag := utils.Json.Get(res.Body(), "etag").ToString() + parts = append(parts, Part{ + PartNumber: i, + ETag: etag, + }) + } + res, err = d.upClient.R().SetHeader("Authorization", "UpToken "+upToken).SetBody(base.Json{ + "fnmae": stream.GetName(), + "parts": parts, + }).Post(fmt.Sprintf("https://upload.qiniup.com/buckets/%s/objects/%s/uploads/%s", d.conf.bucket, keyBase64, uploadId)) + if err != nil { + return nil, err + } + token = utils.Json.Get(res.Body(), "token").ToString() + } + // commit upload + var resp UploadResultResp + for i := 0; i < 10; i++ { + _, err = d.unproved("/7n/results", http.MethodPost, func(req *resty.Request) { + params := []string{ + "tokenList=" + token, + "tokenTime=" + time.Now().Format("Mon Jan 02 2006 15:04:05 GMT-0700 (MST)"), + } + queryString := strings.Join(params, "&") + req.SetQueryString(queryString).SetResult(&resp) + }) + if err != nil { + return nil, err + } + if len(resp.List) == 0 { + return nil, fmt.Errorf("upload failed, empty response") + } + if resp.List[0].Status == 1 { + break + } + time.Sleep(time.Second * 1) + } + file := resp.List[0] + if file.Status != 1 { + return nil, fmt.Errorf("upload failed, status: %d", resp.List[0].Status) + } + return &model.Object{ + ID: strconv.FormatInt(file.FileId, 10), + //Path: , + Name: file.FileName, + Size: stream.GetSize(), + Modified: stream.ModTime(), + Ctime: stream.CreateTime(), + IsFolder: false, + HashInfo: utils.NewHashInfo(utils.MD5, etag), + }, nil +} + +//func (d *ILanZou) Other(ctx context.Context, args model.OtherArgs) (interface{}, error) { +// return nil, errs.NotSupport +//} + +var _ driver.Driver = (*ILanZou)(nil) diff --git a/drivers/ilanzou/meta.go b/drivers/ilanzou/meta.go new file mode 100644 index 0000000000000000000000000000000000000000..f15fc01a492bfc9082d4dc8311e760e2da56ab8e --- /dev/null +++ b/drivers/ilanzou/meta.go @@ -0,0 +1,80 @@ +package template + +import ( + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/op" +) + +type Addition struct { + driver.RootID + Username string `json:"username" type:"string" required:"true"` + Password string `json:"password" type:"string" required:"true"` + + Token string + UUID string +} + +type Conf struct { + base string + secret []byte + bucket string + unproved string + proved string + devVersion string + site string +} + +func init() { + op.RegisterDriver(func() driver.Driver { + return &ILanZou{ + config: driver.Config{ + Name: "ILanZou", + LocalSort: false, + OnlyLocal: false, + OnlyProxy: false, + NoCache: false, + NoUpload: false, + NeedMs: false, + DefaultRoot: "0", + CheckStatus: 
false, + Alert: "", + NoOverwriteUpload: false, + }, + conf: Conf{ + base: "https://api.ilanzou.com", + secret: []byte("lanZouY-disk-app"), + bucket: "wpanstore-lanzou", + unproved: "unproved", + proved: "proved", + devVersion: "125", + site: "https://www.ilanzou.com", + }, + } + }) + op.RegisterDriver(func() driver.Driver { + return &ILanZou{ + config: driver.Config{ + Name: "FeijiPan", + LocalSort: false, + OnlyLocal: false, + OnlyProxy: false, + NoCache: false, + NoUpload: false, + NeedMs: false, + DefaultRoot: "0", + CheckStatus: false, + Alert: "", + NoOverwriteUpload: false, + }, + conf: Conf{ + base: "https://api.feijipan.com", + secret: []byte("dingHao-disk-app"), + bucket: "wpanstore", + unproved: "ws", + proved: "app", + devVersion: "125", + site: "https://www.feijipan.com", + }, + } + }) +} diff --git a/drivers/ilanzou/types.go b/drivers/ilanzou/types.go new file mode 100644 index 0000000000000000000000000000000000000000..135724c749cc29fc72dfeef03a36819ab9a3dbdc --- /dev/null +++ b/drivers/ilanzou/types.go @@ -0,0 +1,57 @@ +package template + +type ListResp struct { + Msg string `json:"msg"` + Total int `json:"total"` + Code int `json:"code"` + Offset int `json:"offset"` + TotalPage int `json:"totalPage"` + Limit int `json:"limit"` + List []ListItem `json:"list"` +} + +type ListItem struct { + IconId int `json:"iconId"` + IsAmt int `json:"isAmt"` + FolderDesc string `json:"folderDesc,omitempty"` + AddTime string `json:"addTime"` + FolderId int64 `json:"folderId"` + ParentId int64 `json:"parentId"` + ParentName string `json:"parentName"` + NoteType int `json:"noteType,omitempty"` + UpdTime string `json:"updTime"` + IsShare int `json:"isShare"` + FolderIcon string `json:"folderIcon,omitempty"` + FolderName string `json:"folderName,omitempty"` + FileType int `json:"fileType"` + Status int `json:"status"` + IsFileShare int `json:"isFileShare,omitempty"` + FileName string `json:"fileName,omitempty"` + FileStars float64 `json:"fileStars,omitempty"` + IsFileDownload int `json:"isFileDownload,omitempty"` + FileComments int `json:"fileComments,omitempty"` + FileSize int64 `json:"fileSize,omitempty"` + FileIcon string `json:"fileIcon,omitempty"` + FileDownloads int `json:"fileDownloads,omitempty"` + FileUrl interface{} `json:"fileUrl"` + FileLikes int `json:"fileLikes,omitempty"` + FileId int64 `json:"fileId,omitempty"` +} + +type Part struct { + PartNumber int `json:"partNumber"` + ETag string `json:"etag"` +} + +type UploadResultResp struct { + Msg string `json:"msg"` + Code int `json:"code"` + List []struct { + FileIconId int `json:"fileIconId"` + FileName string `json:"fileName"` + FileIcon string `json:"fileIcon"` + FileId int64 `json:"fileId"` + Status int `json:"status"` + Token string `json:"token"` + } `json:"list"` +} diff --git a/drivers/ilanzou/util.go b/drivers/ilanzou/util.go new file mode 100644 index 0000000000000000000000000000000000000000..a57e2a4a6bec489c9c9dac0e12ead89bde583ddc --- /dev/null +++ b/drivers/ilanzou/util.go @@ -0,0 +1,111 @@ +package template + +import ( + "encoding/hex" + "fmt" + "net/http" + "net/url" + "strconv" + "strings" + "time" + + "github.com/alist-org/alist/v3/drivers/base" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/foxxorcat/mopan-sdk-go" + "github.com/go-resty/resty/v2" + log "github.com/sirupsen/logrus" +) + +func (d *ILanZou) login() error { + res, err := d.unproved("/login", http.MethodPost, func(req *resty.Request) { + req.SetBody(base.Json{ + "loginName": d.Username, + "loginPwd": d.Password, + }) + }) + if err != nil { + 
return err
+	}
+	d.Token = utils.Json.Get(res, "data", "appToken").ToString()
+	if d.Token == "" {
+		return fmt.Errorf("failed to login: token is empty, resp: %s", res)
+	}
+	return nil
+}
+
+func getTimestamp(secret []byte) (int64, string, error) {
+	ts := time.Now().UnixMilli()
+	tsStr := strconv.FormatInt(ts, 10)
+	res, err := mopan.AesEncrypt([]byte(tsStr), secret)
+	if err != nil {
+		return 0, "", err
+	}
+	return ts, hex.EncodeToString(res), nil
+}
+
+func (d *ILanZou) request(pathname, method string, callback base.ReqCallback, proved bool, retry ...bool) ([]byte, error) {
+	_, ts_str, err := getTimestamp(d.conf.secret)
+	if err != nil {
+		return nil, err
+	}
+
+	params := []string{
+		"uuid=" + url.QueryEscape(d.UUID),
+		"devType=6",
+		"devCode=" + url.QueryEscape(d.UUID),
+		"devModel=chrome",
+		"devVersion=" + url.QueryEscape(d.conf.devVersion),
+		"appVersion=",
+		"timestamp=" + ts_str,
+	}
+
+	if proved {
+		params = append(params, "appToken="+url.QueryEscape(d.Token))
+	}
+
+	params = append(params, "extra=2")
+
+	queryString := strings.Join(params, "&")
+
+	req := base.RestyClient.R()
+	req.SetHeaders(map[string]string{
+		"Origin":     d.conf.site,
+		"Referer":    d.conf.site + "/",
+		"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/125.0.0.0 Safari/537.36 Edg/125.0.0.0",
+	})
+
+	if callback != nil {
+		callback(req)
+	}
+
+	res, err := req.Execute(method, d.conf.base+pathname+"?"+queryString)
+	if err != nil {
+		if res != nil {
+			log.Errorf("[iLanZou] request error: %s", res.String())
+		}
+		return nil, err
+	}
+	isRetry := len(retry) > 0 && retry[0]
+	body := res.Body()
+	code := utils.Json.Get(body, "code").ToInt()
+	msg := utils.Json.Get(body, "msg").ToString()
+	if code != 200 {
+		if !isRetry && proved && (utils.SliceContains([]int{-1, -2}, code) || d.Token == "") {
+			err = d.login()
+			if err != nil {
+				return nil, err
+			}
+			return d.request(pathname, method, callback, proved, true)
+		}
+		return nil, fmt.Errorf("%d: %s", code, msg)
+	}
+	return body, nil
+}
+
+func (d *ILanZou) unproved(pathname, method string, callback base.ReqCallback) ([]byte, error) {
+	return d.request("/"+d.conf.unproved+pathname, method, callback, false)
+}
+
+func (d *ILanZou) proved(pathname, method string, callback base.ReqCallback) ([]byte, error) {
+	return d.request("/"+d.conf.proved+pathname, method, callback, true)
+}
diff --git a/drivers/ipfs_api/driver.go b/drivers/ipfs_api/driver.go
new file mode 100644
index 0000000000000000000000000000000000000000..f6f81305e2021398483663701c57ba6b18f51fc9
--- /dev/null
+++ b/drivers/ipfs_api/driver.go
@@ -0,0 +1,128 @@
+package ipfs
+
+import (
+	"context"
+	"fmt"
+	"net/url"
+	stdpath "path"
+	"path/filepath"
+	"strings"
+
+	"github.com/alist-org/alist/v3/internal/driver"
+	"github.com/alist-org/alist/v3/internal/model"
+	shell "github.com/ipfs/go-ipfs-api"
+)
+
+type IPFS struct {
+	model.Storage
+	Addition
+	sh      *shell.Shell
+	gateURL *url.URL
+}
+
+func (d *IPFS) Config() driver.Config {
+	return config
+}
+
+func (d *IPFS) GetAddition() driver.Additional {
+	return &d.Addition
+}
+
+func (d *IPFS) Init(ctx context.Context) error {
+	d.sh = shell.NewShell(d.Endpoint)
+	gateURL, err := url.Parse(d.Gateway)
+	if err != nil {
+		return err
+	}
+	d.gateURL = gateURL
+	return nil
+}
+
+func (d *IPFS) Drop(ctx context.Context) error {
+	return nil
+}
+
+func (d *IPFS) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) {
+	path := dir.GetPath()
+	if !strings.HasSuffix(path, "/") {
+		path += "/"
+	}
+
+	path_cid, err := d.sh.FilesStat(ctx, path)
+	if err != nil {
+		return nil, err
+	}
+
+	dirs, err := d.sh.List(path_cid.Hash)
+	if err != nil {
+		return nil, err
+	}
+
+	objlist := []model.Obj{}
+	for _, file := range dirs {
+		gateurl := *d.gateURL
+		gateurl.Path = "ipfs/" + file.Hash
+		gateurl.RawQuery = "filename=" + url.PathEscape(file.Name)
+		objlist = append(objlist, &model.ObjectURL{
+			Object: model.Object{ID: file.Hash, Name: file.Name, Size: int64(file.Size), IsFolder: file.Type == 1},
+			Url:    model.Url{Url: gateurl.String()},
+		})
+	}
+
+	return objlist, nil
+}
+
+func (d *IPFS) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) {
+	link := d.Gateway + "/ipfs/" + file.GetID() + "/?filename=" + url.PathEscape(file.GetName())
+	return &model.Link{URL: link}, nil
+}
+
+func (d *IPFS) MakeDir(ctx context.Context, parentDir model.Obj, dirName string) error {
+	path := parentDir.GetPath()
+	if !strings.HasSuffix(path, "/") {
+		path += "/"
+	}
+	return d.sh.FilesMkdir(ctx, path+dirName)
+}
+
+func (d *IPFS) Move(ctx context.Context, srcObj, dstDir model.Obj) error {
+	return d.sh.FilesMv(ctx, srcObj.GetPath(), dstDir.GetPath())
+}
+
+func (d *IPFS) Rename(ctx context.Context, srcObj model.Obj, newName string) error {
+	newFileName := filepath.Dir(srcObj.GetPath()) + "/" + newName
+	return d.sh.FilesMv(ctx, srcObj.GetPath(), strings.ReplaceAll(newFileName, "\\", "/"))
+}
+
+func (d *IPFS) Copy(ctx context.Context, srcObj, dstDir model.Obj) error {
+	// TODO copy obj, optional
+	fmt.Println(srcObj.GetPath())
+	fmt.Println(dstDir.GetPath())
+	newFileName := dstDir.GetPath() + "/" + filepath.Base(srcObj.GetPath())
+	fmt.Println(newFileName)
+	return d.sh.FilesCp(ctx, srcObj.GetPath(), strings.ReplaceAll(newFileName, "\\", "/"))
+}
+
+func (d *IPFS) Remove(ctx context.Context, obj model.Obj) error {
+	// TODO remove obj, optional
+	return d.sh.FilesRm(ctx, obj.GetPath(), true)
+}
+
+func (d *IPFS) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error {
+	// TODO upload file, optional
+	_, err := d.sh.Add(stream, ToFiles(stdpath.Join(dstDir.GetPath(), stream.GetName())))
+	return err
+}
+
+func ToFiles(dstDir string) shell.AddOpts {
+	return func(rb *shell.RequestBuilder) error {
+		rb.Option("to-files", dstDir)
+		return nil
+	}
+}
+
+//func (d *Template) Other(ctx context.Context, args model.OtherArgs) (interface{}, error) {
+//	return nil, errs.NotSupport
+//}
+
+var _ driver.Driver = (*IPFS)(nil)
diff --git a/drivers/ipfs_api/meta.go b/drivers/ipfs_api/meta.go
new file mode 100644
index 0000000000000000000000000000000000000000..cdc3042434bc61d9a370e52b740d87412f35e023
--- /dev/null
+++ b/drivers/ipfs_api/meta.go
@@ -0,0 +1,25 @@
+package ipfs
+
+import (
+	"github.com/alist-org/alist/v3/internal/driver"
+	"github.com/alist-org/alist/v3/internal/op"
+)
+
+type Addition struct {
+	// Usually one of two
+	driver.RootPath
+	Endpoint string `json:"endpoint" default:"http://127.0.0.1:5001"`
+	Gateway  string `json:"gateway" default:"https://ipfs.io"`
+}
+
+var config = driver.Config{
+	Name:        "IPFS API",
+	DefaultRoot: "/",
+	LocalSort:   true,
+}
+
+func init() {
+	op.RegisterDriver(func() driver.Driver {
+		return &IPFS{}
+	})
+}
diff --git a/drivers/kodbox/driver.go b/drivers/kodbox/driver.go
new file mode 100644
index 0000000000000000000000000000000000000000..eb5120a67c11de0ec770cf66f5af4859fc4fe714
--- /dev/null
+++ b/drivers/kodbox/driver.go
@@ -0,0 +1,273 @@
+package kodbox
+
+import (
+	"context"
"fmt" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/go-resty/resty/v2" + "net/http" + "path/filepath" + "strings" + "time" + + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/model" +) + +type KodBox struct { + model.Storage + Addition + authorization string +} + +func (d *KodBox) Config() driver.Config { + return config +} + +func (d *KodBox) GetAddition() driver.Additional { + return &d.Addition +} + +func (d *KodBox) Init(ctx context.Context) error { + d.Address = strings.TrimSuffix(d.Address, "/") + d.RootFolderPath = strings.TrimPrefix(utils.FixAndCleanPath(d.RootFolderPath), "/") + return d.getToken() +} + +func (d *KodBox) Drop(ctx context.Context) error { + return nil +} + +func (d *KodBox) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) { + var ( + resp *CommonResp + listPathData *ListPathData + ) + + _, err := d.request(http.MethodPost, "/?explorer/list/path", func(req *resty.Request) { + req.SetResult(&resp).SetFormData(map[string]string{ + "path": dir.GetPath(), + }) + }, true) + if err != nil { + return nil, err + } + + dataBytes, err := utils.Json.Marshal(resp.Data) + if err != nil { + return nil, err + } + + err = utils.Json.Unmarshal(dataBytes, &listPathData) + if err != nil { + return nil, err + } + FolderAndFiles := append(listPathData.FolderList, listPathData.FileList...) + + return utils.SliceConvert(FolderAndFiles, func(f FolderOrFile) (model.Obj, error) { + return &model.ObjThumb{ + Object: model.Object{ + Path: f.Path, + Name: f.Name, + Ctime: time.Unix(f.CreateTime, 0), + Modified: time.Unix(f.ModifyTime, 0), + Size: f.Size, + IsFolder: f.Type == "folder", + }, + //Thumbnail: model.Thumbnail{}, + }, nil + }) +} + +func (d *KodBox) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) { + path := file.GetPath() + return &model.Link{ + URL: fmt.Sprintf("%s/?explorer/index/fileOut&path=%s&download=1&accessToken=%s", + d.Address, + path, + d.authorization)}, nil +} + +func (d *KodBox) MakeDir(ctx context.Context, parentDir model.Obj, dirName string) (model.Obj, error) { + var resp *CommonResp + newDirPath := filepath.Join(parentDir.GetPath(), dirName) + + _, err := d.request(http.MethodPost, "/?explorer/index/mkdir", func(req *resty.Request) { + req.SetResult(&resp).SetFormData(map[string]string{ + "path": newDirPath, + }) + }) + if err != nil { + return nil, err + } + code := resp.Code.(bool) + if !code { + return nil, fmt.Errorf("%s", resp.Data) + } + + return &model.ObjThumb{ + Object: model.Object{ + Path: resp.Info.(string), + Name: dirName, + IsFolder: true, + Modified: time.Now(), + Ctime: time.Now(), + }, + }, nil +} + +func (d *KodBox) Move(ctx context.Context, srcObj, dstDir model.Obj) (model.Obj, error) { + var resp *CommonResp + _, err := d.request(http.MethodPost, "/?explorer/index/pathCuteTo", func(req *resty.Request) { + req.SetResult(&resp).SetFormData(map[string]string{ + "dataArr": fmt.Sprintf("[{\"path\": \"%s\", \"name\": \"%s\"}]", + srcObj.GetPath(), + srcObj.GetName()), + "path": dstDir.GetPath(), + }) + }, true) + if err != nil { + return nil, err + } + code := resp.Code.(bool) + if !code { + return nil, fmt.Errorf("%s", resp.Data) + } + + return &model.ObjThumb{ + Object: model.Object{ + Path: srcObj.GetPath(), + Name: srcObj.GetName(), + IsFolder: srcObj.IsDir(), + Modified: srcObj.ModTime(), + Ctime: srcObj.CreateTime(), + }, + }, nil +} + +func (d *KodBox) Rename(ctx context.Context, srcObj model.Obj, newName string) 
(model.Obj, error) { + var resp *CommonResp + _, err := d.request(http.MethodPost, "/?explorer/index/pathRename", func(req *resty.Request) { + req.SetResult(&resp).SetFormData(map[string]string{ + "path": srcObj.GetPath(), + "newName": newName, + }) + }, true) + if err != nil { + return nil, err + } + code := resp.Code.(bool) + if !code { + return nil, fmt.Errorf("%s", resp.Data) + } + return &model.ObjThumb{ + Object: model.Object{ + Path: srcObj.GetPath(), + Name: newName, + IsFolder: srcObj.IsDir(), + Modified: time.Now(), + Ctime: srcObj.CreateTime(), + }, + }, nil +} + +func (d *KodBox) Copy(ctx context.Context, srcObj, dstDir model.Obj) (model.Obj, error) { + var resp *CommonResp + _, err := d.request(http.MethodPost, "/?explorer/index/pathCopyTo", func(req *resty.Request) { + req.SetResult(&resp).SetFormData(map[string]string{ + "dataArr": fmt.Sprintf("[{\"path\": \"%s\", \"name\": \"%s\"}]", + srcObj.GetPath(), + srcObj.GetName()), + "path": dstDir.GetPath(), + }) + }) + if err != nil { + return nil, err + } + code := resp.Code.(bool) + if !code { + return nil, fmt.Errorf("%s", resp.Data) + } + + path := resp.Info.([]interface{})[0].(string) + objectName, err := d.getFileOrFolderName(ctx, path) + if err != nil { + return nil, err + } + return &model.ObjThumb{ + Object: model.Object{ + Path: path, + Name: *objectName, + IsFolder: srcObj.IsDir(), + Modified: time.Now(), + Ctime: time.Now(), + }, + }, nil +} + +func (d *KodBox) Remove(ctx context.Context, obj model.Obj) error { + var resp *CommonResp + _, err := d.request(http.MethodPost, "/?explorer/index/pathDelete", func(req *resty.Request) { + req.SetResult(&resp).SetFormData(map[string]string{ + "dataArr": fmt.Sprintf("[{\"path\": \"%s\", \"name\": \"%s\"}]", + obj.GetPath(), + obj.GetName()), + "shiftDelete": "1", + }) + }) + if err != nil { + return err + } + code := resp.Code.(bool) + if !code { + return fmt.Errorf("%s", resp.Data) + } + return nil +} + +func (d *KodBox) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) (model.Obj, error) { + var resp *CommonResp + _, err := d.request(http.MethodPost, "/?explorer/upload/fileUpload", func(req *resty.Request) { + req.SetFileReader("file", stream.GetName(), stream). + SetResult(&resp). 
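+			// kodbox reads the upload destination from the "path" form field appended below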
+ SetFormData(map[string]string{ + "path": dstDir.GetPath(), + }) + }) + if err != nil { + return nil, err + } + code := resp.Code.(bool) + if !code { + return nil, fmt.Errorf("%s", resp.Data) + } + return &model.ObjThumb{ + Object: model.Object{ + Path: resp.Info.(string), + Name: stream.GetName(), + Size: stream.GetSize(), + IsFolder: false, + Modified: time.Now(), + Ctime: time.Now(), + }, + }, nil +} + +func (d *KodBox) getFileOrFolderName(ctx context.Context, path string) (*string, error) { + var resp *CommonResp + _, err := d.request(http.MethodPost, "/?explorer/index/pathInfo", func(req *resty.Request) { + req.SetResult(&resp).SetFormData(map[string]string{ + "dataArr": fmt.Sprintf("[{\"path\": \"%s\"}]", path)}) + }) + if err != nil { + return nil, err + } + code := resp.Code.(bool) + if !code { + return nil, fmt.Errorf("%s", resp.Data) + } + folderOrFileName := resp.Data.(map[string]any)["name"].(string) + return &folderOrFileName, nil +} + +var _ driver.Driver = (*KodBox)(nil) diff --git a/drivers/kodbox/meta.go b/drivers/kodbox/meta.go new file mode 100644 index 0000000000000000000000000000000000000000..318fb9ec56f5e4f7a9f275bb03ecced9c4b2893e --- /dev/null +++ b/drivers/kodbox/meta.go @@ -0,0 +1,25 @@ +package kodbox + +import ( + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/op" +) + +type Addition struct { + driver.RootPath + + Address string `json:"address" required:"true"` + UserName string `json:"username" required:"false"` + Password string `json:"password" required:"false"` +} + +var config = driver.Config{ + Name: "KodBox", + DefaultRoot: "", +} + +func init() { + op.RegisterDriver(func() driver.Driver { + return &KodBox{} + }) +} diff --git a/drivers/kodbox/types.go b/drivers/kodbox/types.go new file mode 100644 index 0000000000000000000000000000000000000000..9bd45d9b366c7e657ef147111eca04ba30161172 --- /dev/null +++ b/drivers/kodbox/types.go @@ -0,0 +1,24 @@ +package kodbox + +type CommonResp struct { + Code any `json:"code"` + TimeUse string `json:"timeUse"` + TimeNow string `json:"timeNow"` + Data any `json:"data"` + Info any `json:"info"` +} + +type ListPathData struct { + FolderList []FolderOrFile `json:"folderList"` + FileList []FolderOrFile `json:"fileList"` +} + +type FolderOrFile struct { + Name string `json:"name"` + Path string `json:"path"` + Type string `json:"type"` + Ext string `json:"ext,omitempty"` // 文件特有字段 + Size int64 `json:"size"` + CreateTime int64 `json:"createTime"` + ModifyTime int64 `json:"modifyTime"` +} diff --git a/drivers/kodbox/util.go b/drivers/kodbox/util.go new file mode 100644 index 0000000000000000000000000000000000000000..2c04cd73f29bfebf766ac9fbf6504bcc229f95d4 --- /dev/null +++ b/drivers/kodbox/util.go @@ -0,0 +1,86 @@ +package kodbox + +import ( + "fmt" + "github.com/alist-org/alist/v3/drivers/base" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/go-resty/resty/v2" + "strings" +) + +func (d *KodBox) getToken() error { + var authResp CommonResp + res, err := base.RestyClient.R(). + SetResult(&authResp). + SetQueryParams(map[string]string{ + "name": d.UserName, + "password": d.Password, + }). 
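+		// on success the accessToken comes back in the "info" field and is cached as d.authorization below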
+ Post(d.Address + "/?user/index/loginSubmit") + if err != nil { + return err + } + if res.StatusCode() >= 400 { + return fmt.Errorf("get token failed: %s", res.String()) + } + + if res.StatusCode() == 200 && authResp.Code.(bool) == false { + return fmt.Errorf("get token failed: %s", res.String()) + } + + d.authorization = fmt.Sprintf("%s", authResp.Info) + return nil +} + +func (d *KodBox) request(method string, pathname string, callback base.ReqCallback, noRedirect ...bool) ([]byte, error) { + full := pathname + if !strings.HasPrefix(pathname, "http") { + full = d.Address + pathname + } + req := base.RestyClient.R() + if len(noRedirect) > 0 && noRedirect[0] { + req = base.NoRedirectClient.R() + } + req.SetFormData(map[string]string{ + "accessToken": d.authorization, + }) + callback(req) + + var ( + res *resty.Response + commonResp *CommonResp + err error + skip bool + ) + for i := 0; i < 2; i++ { + if skip { + break + } + res, err = req.Execute(method, full) + if err != nil { + return nil, err + } + + err := utils.Json.Unmarshal(res.Body(), &commonResp) + if err != nil { + return nil, err + } + + switch commonResp.Code.(type) { + case bool: + skip = true + case string: + if commonResp.Code.(string) == "10001" { + err = d.getToken() + if err != nil { + return nil, err + } + req.SetFormData(map[string]string{"accessToken": d.authorization}) + } + } + } + if commonResp.Code.(bool) == false { + return nil, fmt.Errorf("request failed: %s", commonResp.Data) + } + return res.Body(), nil +} diff --git a/drivers/lanzou/driver.go b/drivers/lanzou/driver.go new file mode 100644 index 0000000000000000000000000000000000000000..9e73f0525c2da9f05c7f645a171f7377abadcd4e --- /dev/null +++ b/drivers/lanzou/driver.go @@ -0,0 +1,230 @@ +package lanzou + +import ( + "context" + "net/http" + + "github.com/alist-org/alist/v3/drivers/base" + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/errs" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/go-resty/resty/v2" +) + +type LanZou struct { + Addition + model.Storage + uid string + vei string + + flag int32 +} + +func (d *LanZou) Config() driver.Config { + return config +} + +func (d *LanZou) GetAddition() driver.Additional { + return &d.Addition +} + +func (d *LanZou) Init(ctx context.Context) (err error) { + if d.UserAgent == "" { + d.UserAgent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.39 (KHTML, like Gecko) Chrome/89.0.4389.111 Safari/537.39" + } + switch d.Type { + case "account": + _, err := d.Login() + if err != nil { + return err + } + fallthrough + case "cookie": + if d.RootFolderID == "" { + d.RootFolderID = "-1" + } + d.vei, d.uid, err = d.getVeiAndUid() + } + return +} + +func (d *LanZou) Drop(ctx context.Context) error { + d.uid = "" + return nil +} + +// 获取的大小和时间不准确 +func (d *LanZou) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) { + if d.IsCookie() || d.IsAccount() { + return d.GetAllFiles(dir.GetID()) + } else { + return d.GetFileOrFolderByShareUrl(dir.GetID(), d.SharePassword) + } +} + +func (d *LanZou) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) { + var ( + err error + dfile *FileOrFolderByShareUrl + ) + switch file := file.(type) { + case *FileOrFolder: + // 先获取分享链接 + sfile := file.GetShareInfo() + if sfile == nil { + sfile, err = d.getFileShareUrlByID(file.GetID()) + if err != nil { + return nil, err + } + file.SetShareInfo(sfile) + } + + // 然后获取下载链接 + 
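+			// GetFilesByShareUrl resolves the share page (and password, if any) to a direct download link; see util.go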
dfile, err = d.GetFilesByShareUrl(sfile.FID, sfile.Pwd)
+		if err != nil {
+			return nil, err
+		}
+		// 修复文件大小
+		if d.RepairFileInfo && !file.repairFlag {
+			size, time := d.getFileRealInfo(dfile.Url)
+			if size != nil {
+				file.size = size
+				file.repairFlag = true
+			}
+			if time != nil {
+				file.time = time
+			}
+		}
+	case *FileOrFolderByShareUrl:
+		dfile, err = d.GetFilesByShareUrl(file.GetID(), file.Pwd)
+		if err != nil {
+			return nil, err
+		}
+		// 修复文件大小
+		if d.RepairFileInfo && !file.repairFlag {
+			size, time := d.getFileRealInfo(dfile.Url)
+			if size != nil {
+				file.size = size
+				file.repairFlag = true
+			}
+			if time != nil {
+				file.time = time
+			}
+		}
+	}
+	exp := GetExpirationTime(dfile.Url)
+	return &model.Link{
+		URL: dfile.Url,
+		Header: http.Header{
+			"User-Agent": []string{base.UserAgent},
+		},
+		Expiration: &exp,
+	}, nil
+}
+
+func (d *LanZou) MakeDir(ctx context.Context, parentDir model.Obj, dirName string) (model.Obj, error) {
+	if d.IsCookie() || d.IsAccount() {
+		data, err := d.doupload(func(req *resty.Request) {
+			req.SetContext(ctx)
+			req.SetFormData(map[string]string{
+				"task":               "2",
+				"parent_id":          parentDir.GetID(),
+				"folder_name":        dirName,
+				"folder_description": "",
+			})
+		}, nil)
+		if err != nil {
+			return nil, err
+		}
+		return &FileOrFolder{
+			Name:  dirName,
+			FolID: utils.Json.Get(data, "text").ToString(),
+		}, nil
+	}
+	return nil, errs.NotSupport
+}
+
+func (d *LanZou) Move(ctx context.Context, srcObj, dstDir model.Obj) (model.Obj, error) {
+	if d.IsCookie() || d.IsAccount() {
+		if !srcObj.IsDir() {
+			_, err := d.doupload(func(req *resty.Request) {
+				req.SetContext(ctx)
+				req.SetFormData(map[string]string{
+					"task":      "20",
+					"folder_id": dstDir.GetID(),
+					"file_id":   srcObj.GetID(),
+				})
+			}, nil)
+			if err != nil {
+				return nil, err
+			}
+			return srcObj, nil
+		}
+	}
+	return nil, errs.NotSupport
+}
+
+func (d *LanZou) Rename(ctx context.Context, srcObj model.Obj, newName string) (model.Obj, error) {
+	if d.IsCookie() || d.IsAccount() {
+		if !srcObj.IsDir() {
+			_, err := d.doupload(func(req *resty.Request) {
+				req.SetContext(ctx)
+				req.SetFormData(map[string]string{
+					"task":      "46",
+					"file_id":   srcObj.GetID(),
+					"file_name": newName,
+					"type":      "2",
+				})
+			}, nil)
+			if err != nil {
+				return nil, err
+			}
+			srcObj.(*FileOrFolder).NameAll = newName
+			return srcObj, nil
+		}
+	}
+	return nil, errs.NotSupport
+}
+
+func (d *LanZou) Remove(ctx context.Context, obj model.Obj) error {
+	if d.IsCookie() || d.IsAccount() {
+		_, err := d.doupload(func(req *resty.Request) {
+			req.SetContext(ctx)
+			if obj.IsDir() {
+				req.SetFormData(map[string]string{
+					"task":      "3",
+					"folder_id": obj.GetID(),
+				})
+			} else {
+				req.SetFormData(map[string]string{
+					"task":    "6",
+					"file_id": obj.GetID(),
+				})
+			}
+		}, nil)
+		return err
+	}
+	return errs.NotSupport
+}
+
+func (d *LanZou) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) (model.Obj, error) {
+	if d.IsCookie() || d.IsAccount() {
+		var resp RespText[[]FileOrFolder]
+		_, err := d._post(d.BaseUrl+"/html5up.php", func(req *resty.Request) {
+			req.SetFormData(map[string]string{
+				"task":           "1",
+				"vie":            "2",
+				"ve":             "2",
+				"id":             "WU_FILE_0",
+				"name":           stream.GetName(),
+				"folder_id_bb_n": dstDir.GetID(),
+			}).SetFileReader("upload_file", stream.GetName(), stream).SetContext(ctx)
+		}, &resp, true)
+		if err != nil {
+			return nil, err
+		}
+		return &resp.Text[0], nil
+	}
+	return nil, errs.NotSupport
+}
diff --git a/drivers/lanzou/help.go b/drivers/lanzou/help.go
new file mode 100644
index 
0000000000000000000000000000000000000000..81d7c567d5c1d9fcd6bcf6dc102be9d533a7ef11
--- /dev/null
+++ b/drivers/lanzou/help.go
@@ -0,0 +1,293 @@
+package lanzou
+
+import (
+	"bytes"
+	"fmt"
+	"net/http"
+	"regexp"
+	"strconv"
+	"strings"
+	"time"
+	"unicode"
+
+	log "github.com/sirupsen/logrus"
+)
+
+const DAY time.Duration = 86400000000000
+
+// 解析时间
+var timeSplitReg = regexp.MustCompile("([0-9.]*)\\s*([\u4e00-\u9fa5]+)")
+
+// 如果解析失败,则返回当前时间
+func MustParseTime(str string) time.Time {
+	lastOpTime, err := time.ParseInLocation("2006-01-02 -07", str+" +08", time.Local)
+	if err != nil {
+		strs := timeSplitReg.FindStringSubmatch(str)
+		lastOpTime = time.Now()
+		if len(strs) == 3 {
+			i, _ := strconv.ParseInt(strs[1], 10, 64)
+			ti := time.Duration(-i)
+			switch strs[2] {
+			case "秒前":
+				lastOpTime = lastOpTime.Add(time.Second * ti)
+			case "分钟前":
+				lastOpTime = lastOpTime.Add(time.Minute * ti)
+			case "小时前":
+				lastOpTime = lastOpTime.Add(time.Hour * ti)
+			case "天前":
+				lastOpTime = lastOpTime.Add(DAY * ti)
+			case "昨天":
+				lastOpTime = lastOpTime.Add(-DAY)
+			case "前天":
+				lastOpTime = lastOpTime.Add(-DAY * 2)
+			}
+		}
+	}
+	return lastOpTime
+}
+
+// 解析大小
+var sizeSplitReg = regexp.MustCompile(`(?i)([0-9.]+)\s*([bkm]+)`)
+
+// 解析失败返回0
+func SizeStrToInt64(size string) int64 {
+	strs := sizeSplitReg.FindStringSubmatch(size)
+	if len(strs) < 3 {
+		return 0
+	}
+
+	s, _ := strconv.ParseFloat(strs[1], 64)
+	switch strings.ToUpper(strs[2]) {
+	case "B":
+		return int64(s)
+	case "K":
+		return int64(s * (1 << 10))
+	case "M":
+		return int64(s * (1 << 20))
+	}
+	return 0
+}
+
+// 移除注释
+func RemoveNotes(html string) string {
+	return regexp.MustCompile(`<!--.*?-->|[^:]//.*|/\*.*?\*/`).ReplaceAllStringFunc(html, func(b string) string {
+		if b[1:3] == "//" {
+			return b[:1]
+		}
+		return "\n"
+	})
+}
+
+var findAcwScV2Reg = regexp.MustCompile(`arg1='([0-9A-Z]+)'`)
+
+// 在页面被过多访问或其他情况下,有时候会先返回一个加密的页面,其执行计算出一个acw_sc__v2后放入页面后再重新访问页面才能获得正常页面
+// 若该页面进行了js加密,则进行解密,计算acw_sc__v2,并加入cookie
+func CalcAcwScV2(html string) (string, error) {
+	log.Debugln("acw_sc__v2", html)
+	acwScV2s := findAcwScV2Reg.FindStringSubmatch(html)
+	if len(acwScV2s) != 2 {
+		return "", fmt.Errorf("无法匹配acw_sc__v2")
+	}
+	return HexXor(Unbox(acwScV2s[1]), "3000176000856006061501533003690027800375"), nil
+}
+
+func Unbox(hex string) string {
+	var box = []int{6, 28, 34, 31, 33, 18, 30, 23, 9, 8, 19, 38, 17, 24, 0, 5, 32, 21, 10, 22, 25, 14, 15, 3, 16, 27, 13, 35, 2, 29, 11, 26, 4, 36, 1, 39, 37, 7, 20, 12}
+	var newBox = make([]byte, len(hex))
+	for i := 0; i < len(box); i++ {
+		j := box[i]
+		if len(newBox) > j {
+			newBox[j] = hex[i]
+		}
+	}
+	return string(newBox)
+}
+
+func HexXor(hex1, hex2 string) string {
+	out := bytes.NewBuffer(make([]byte, 0, len(hex1)))
+	for i := 0; i < len(hex1) && i < len(hex2); i += 2 {
+		v1, _ := strconv.ParseInt(hex1[i:i+2], 16, 64)
+		v2, _ := strconv.ParseInt(hex2[i:i+2], 16, 64)
+		out.WriteString(strconv.FormatInt(v1^v2, 16))
+	}
+	return out.String()
+}
+
+var findDataReg = regexp.MustCompile(`data[:\s]+({[^}]+})`) // 查找json
+var findKVReg = regexp.MustCompile(`'(.+?)':('?([^' },]*)'?)`) // 拆分kv
+
+// 根据key查询js变量
+func findJSVarFunc(key, data string) string {
+	var values []string
+	if key != "sasign" {
+		values = regexp.MustCompile(`var ` + key + `\s*=\s*['"]?(.+?)['"]?;`).FindStringSubmatch(data)
+	} else {
+		matches := regexp.MustCompile(`var `+key+`\s*=\s*['"]?(.+?)['"]?;`).FindAllStringSubmatch(data, -1)
+		if len(matches) == 3 {
+			values = matches[1]
+		} else {
+			if len(matches) > 0 {
+				values = matches[0]
+			}
+		}
+	}
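+	// sasign may appear more than once in the page, hence the FindAll branch above; other keys take the first match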
if len(values) == 0 { + return "" + } + return values[1] +} + +var findFunction = regexp.MustCompile(`(?ims)^function[^{]+`) +var findFunctionAll = regexp.MustCompile(`(?is)function[^{]+`) + +// 查找所有方法位置 +func findJSFunctionIndex(data string, all bool) [][2]int { + findFunction := findFunction + if all { + findFunction = findFunctionAll + } + + indexs := findFunction.FindAllStringIndex(data, -1) + fIndexs := make([][2]int, 0, len(indexs)) + + for _, index := range indexs { + if len(index) != 2 { + continue + } + count, data := 0, data[index[1]:] + for ii, v := range data { + if v == ' ' && count == 0 { + continue + } + if v == '{' { + count++ + } + + if v == '}' { + count-- + } + if count == 0 { + fIndexs = append(fIndexs, [2]int{index[0], index[1] + ii + 1}) + break + } + } + } + return fIndexs +} + +// 删除JS全局方法 +func removeJSGlobalFunction(html string) string { + indexs := findJSFunctionIndex(html, false) + block := make([]string, len(indexs)) + for i, next := len(indexs)-1, len(html); i >= 0; i-- { + index := indexs[i] + block[i] = html[index[1]:next] + next = index[0] + } + return strings.Join(block, "") +} + +// 根据名称获取方法 +func getJSFunctionByName(html string, name string) (string, error) { + indexs := findJSFunctionIndex(html, true) + for _, index := range indexs { + data := html[index[0]:index[1]] + if regexp.MustCompile(`function\s+` + name + `[()\s]+{`).MatchString(data) { + return data, nil + } + } + return "", fmt.Errorf("not find %s function", name) +} + +// 解析html中的JSON,选择最长的数据 +func htmlJsonToMap2(html string) (map[string]string, error) { + datas := findDataReg.FindAllStringSubmatch(html, -1) + var sData string + for _, data := range datas { + if len(datas) > 0 && len(data[1]) > len(sData) { + sData = data[1] + } + } + if sData == "" { + return nil, fmt.Errorf("not find data") + } + return jsonToMap(sData, html), nil +} + +// 解析html中的JSON +func htmlJsonToMap(html string) (map[string]string, error) { + datas := findDataReg.FindStringSubmatch(html) + if len(datas) != 2 { + return nil, fmt.Errorf("not find data") + } + return jsonToMap(datas[1], html), nil +} + +func jsonToMap(data, html string) map[string]string { + var param = make(map[string]string) + kvs := findKVReg.FindAllStringSubmatch(data, -1) + for _, kv := range kvs { + k, v := kv[1], kv[3] + if v == "" || strings.Contains(kv[2], "'") || IsNumber(kv[2]) { + param[k] = v + } else { + param[k] = findJSVarFunc(v, html) + } + } + return param +} + +func IsNumber(str string) bool { + for _, s := range str { + if !unicode.IsDigit(s) { + return false + } + } + return true +} + +var findFromReg = regexp.MustCompile(`data : '(.+?)'`) // 查找from字符串 + +// 解析html中的form +func htmlFormToMap(html string) (map[string]string, error) { + forms := findFromReg.FindStringSubmatch(html) + if len(forms) != 2 { + return nil, fmt.Errorf("not find file sgin") + } + return formToMap(forms[1]), nil +} + +func formToMap(from string) map[string]string { + var param = make(map[string]string) + for _, kv := range strings.Split(from, "&") { + kv := strings.SplitN(kv, "=", 2)[:2] + param[kv[0]] = kv[1] + } + return param +} + +var regExpirationTime = regexp.MustCompile(`e=(\d+)`) + +func GetExpirationTime(url string) (etime time.Duration) { + exps := regExpirationTime.FindStringSubmatch(url) + if len(exps) < 2 { + return + } + timestamp, err := strconv.ParseInt(exps[1], 10, 64) + if err != nil { + return + } + etime = time.Duration(timestamp-time.Now().Unix()) * time.Second + return +} + +func CookieToString(cookies []*http.Cookie) string { + if cookies 
== nil { + return "" + } + cookieStrings := make([]string, len(cookies)) + for i, cookie := range cookies { + cookieStrings[i] = cookie.Name + "=" + cookie.Value + } + return strings.Join(cookieStrings, ";") +} diff --git a/drivers/lanzou/meta.go b/drivers/lanzou/meta.go new file mode 100644 index 0000000000000000000000000000000000000000..1e8826cadebcfad476cfba2269b5cd4a8122f721 --- /dev/null +++ b/drivers/lanzou/meta.go @@ -0,0 +1,42 @@ +package lanzou + +import ( + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/op" +) + +type Addition struct { + Type string `json:"type" type:"select" options:"account,cookie,url" default:"cookie"` + + Account string `json:"account"` + Password string `json:"password"` + + Cookie string `json:"cookie" help:"about 15 days valid, ignore if shareUrl is used"` + + driver.RootID + SharePassword string `json:"share_password"` + BaseUrl string `json:"baseUrl" required:"true" default:"https://pc.woozooo.com" help:"basic URL for file operation"` + ShareUrl string `json:"shareUrl" required:"true" default:"https://pan.lanzoui.com" help:"used to get the sharing page"` + UserAgent string `json:"user_agent" required:"true" default:"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.39 (KHTML, like Gecko) Chrome/89.0.4389.111 Safari/537.39"` + RepairFileInfo bool `json:"repair_file_info" help:"To use webdav, you need to enable it"` +} + +func (a *Addition) IsCookie() bool { + return a.Type == "cookie" +} + +func (a *Addition) IsAccount() bool { + return a.Type == "account" +} + +var config = driver.Config{ + Name: "Lanzou", + LocalSort: true, + DefaultRoot: "-1", +} + +func init() { + op.RegisterDriver(func() driver.Driver { + return &LanZou{} + }) +} diff --git a/drivers/lanzou/types.go b/drivers/lanzou/types.go new file mode 100644 index 0000000000000000000000000000000000000000..d03838ddf7ac708d608493a837c242669704121f --- /dev/null +++ b/drivers/lanzou/types.go @@ -0,0 +1,186 @@ +package lanzou + +import ( + "errors" + "fmt" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/pkg/utils" + "time" +) + +var ErrFileShareCancel = errors.New("file sharing cancellation") +var ErrFileNotExist = errors.New("file does not exist") +var ErrCookieExpiration = errors.New("cookie expiration") + +type RespText[T any] struct { + Text T `json:"text"` +} + +type RespInfo[T any] struct { + Info T `json:"info"` +} + +var _ model.Obj = (*FileOrFolder)(nil) +var _ model.Obj = (*FileOrFolderByShareUrl)(nil) + +type FileOrFolder struct { + Name string `json:"name"` + //Onof string `json:"onof"` // 是否存在提取码 + //IsLock string `json:"is_lock"` + //IsCopyright int `json:"is_copyright"` + + // 文件通用 + ID string `json:"id"` + NameAll string `json:"name_all"` + Size string `json:"size"` + Time string `json:"time"` + //Icon string `json:"icon"` + //Downs string `json:"downs"` + //Filelock string `json:"filelock"` + //IsBakdownload int `json:"is_bakdownload"` + //Bakdownload string `json:"bakdownload"` + //IsDes int `json:"is_des"` // 是否存在描述 + //IsIco int `json:"is_ico"` + + // 文件夹 + FolID string `json:"fol_id"` + //Folderlock string `json:"folderlock"` + //FolderDes string `json:"folder_des"` + + // 缓存字段 + size *int64 `json:"-"` + time *time.Time `json:"-"` + repairFlag bool `json:"-"` + shareInfo *FileShare `json:"-"` +} + +func (f *FileOrFolder) CreateTime() time.Time { + return f.ModTime() +} + +func (f *FileOrFolder) GetHash() utils.HashInfo { + return utils.HashInfo{} +} + +func (f *FileOrFolder) GetID() string { + if 
f.IsDir() { + return f.FolID + } + return f.ID +} +func (f *FileOrFolder) GetName() string { + if f.IsDir() { + return f.Name + } + return f.NameAll +} +func (f *FileOrFolder) GetPath() string { return "" } +func (f *FileOrFolder) GetSize() int64 { + if f.size == nil { + size := SizeStrToInt64(f.Size) + f.size = &size + } + return *f.size +} +func (f *FileOrFolder) IsDir() bool { return f.FolID != "" } +func (f *FileOrFolder) ModTime() time.Time { + if f.time == nil { + time := MustParseTime(f.Time) + f.time = &time + } + return *f.time +} + +func (f *FileOrFolder) SetShareInfo(fs *FileShare) { + f.shareInfo = fs +} +func (f *FileOrFolder) GetShareInfo() *FileShare { + return f.shareInfo +} + +/* 通过ID获取文件/文件夹分享信息 */ +type FileShare struct { + Pwd string `json:"pwd"` + Onof string `json:"onof"` + Taoc string `json:"taoc"` + IsNewd string `json:"is_newd"` + + // 文件 + FID string `json:"f_id"` + + // 文件夹 + NewUrl string `json:"new_url"` + Name string `json:"name"` + Des string `json:"des"` +} + +/* 分享类型为文件夹 */ +type FileOrFolderByShareUrlResp struct { + Text []FileOrFolderByShareUrl `json:"text"` +} +type FileOrFolderByShareUrl struct { + ID string `json:"id"` + NameAll string `json:"name_all"` + + // 文件特有 + Duan string `json:"duan"` + Size string `json:"size"` + Time string `json:"time"` + //Icon string `json:"icon"` + //PIco int `json:"p_ico"` + //T int `json:"t"` + + // 文件夹特有 + IsFloder bool `json:"-"` + + // + Url string `json:"-"` + Pwd string `json:"-"` + + // 缓存字段 + size *int64 `json:"-"` + time *time.Time `json:"-"` + repairFlag bool `json:"-"` +} + +func (f *FileOrFolderByShareUrl) CreateTime() time.Time { + return f.ModTime() +} + +func (f *FileOrFolderByShareUrl) GetHash() utils.HashInfo { + return utils.HashInfo{} +} + +func (f *FileOrFolderByShareUrl) GetID() string { return f.ID } +func (f *FileOrFolderByShareUrl) GetName() string { return f.NameAll } +func (f *FileOrFolderByShareUrl) GetPath() string { return "" } +func (f *FileOrFolderByShareUrl) GetSize() int64 { + if f.size == nil { + size := SizeStrToInt64(f.Size) + f.size = &size + } + return *f.size +} +func (f *FileOrFolderByShareUrl) IsDir() bool { return f.IsFloder } +func (f *FileOrFolderByShareUrl) ModTime() time.Time { + if f.time == nil { + time := MustParseTime(f.Time) + f.time = &time + } + return *f.time +} + +// 获取下载链接的响应 +type FileShareInfoAndUrlResp[T string | int] struct { + Dom string `json:"dom"` + URL string `json:"url"` + Inf T `json:"inf"` +} + +func (u *FileShareInfoAndUrlResp[T]) GetBaseUrl() string { + return fmt.Sprint(u.Dom, "/file") +} + +func (u *FileShareInfoAndUrlResp[T]) GetDownloadUrl() string { + return fmt.Sprint(u.GetBaseUrl(), "/", u.URL) +} diff --git a/drivers/lanzou/util.go b/drivers/lanzou/util.go new file mode 100644 index 0000000000000000000000000000000000000000..4b9959ad53d162a3ab55135c6effe594eec8db88 --- /dev/null +++ b/drivers/lanzou/util.go @@ -0,0 +1,550 @@ +package lanzou + +import ( + "errors" + "fmt" + "net/http" + "regexp" + "runtime" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/alist-org/alist/v3/drivers/base" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/internal/op" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/go-resty/resty/v2" + log "github.com/sirupsen/logrus" +) + +var upClient *resty.Client +var once sync.Once + +func (d *LanZou) doupload(callback base.ReqCallback, resp interface{}) ([]byte, error) { + return d.post(d.BaseUrl+"/doupload.php", func(req *resty.Request) { + 
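+		// every doupload.php task is authorized by the uid/vei pair scraped from mydisk.php (see getVeiAndUid)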
req.SetQueryParams(map[string]string{ + "uid": d.uid, + "vei": d.vei, + }) + if callback != nil { + callback(req) + } + }, resp) +} + +func (d *LanZou) get(url string, callback base.ReqCallback) ([]byte, error) { + return d.request(url, http.MethodGet, callback, false) +} + +func (d *LanZou) post(url string, callback base.ReqCallback, resp interface{}) ([]byte, error) { + data, err := d._post(url, callback, resp, false) + if err == ErrCookieExpiration && d.IsAccount() { + if atomic.CompareAndSwapInt32(&d.flag, 0, 1) { + _, err2 := d.Login() + atomic.SwapInt32(&d.flag, 0) + if err2 != nil { + err = errors.Join(err, err2) + d.Status = err.Error() + op.MustSaveDriverStorage(d) + return data, err + } + } + for atomic.LoadInt32(&d.flag) != 0 { + runtime.Gosched() + } + return d._post(url, callback, resp, false) + } + return data, err +} + +func (d *LanZou) _post(url string, callback base.ReqCallback, resp interface{}, up bool) ([]byte, error) { + data, err := d.request(url, http.MethodPost, func(req *resty.Request) { + req.AddRetryCondition(func(r *resty.Response, err error) bool { + if utils.Json.Get(r.Body(), "zt").ToInt() == 4 { + time.Sleep(time.Second) + return true + } + return false + }) + if callback != nil { + callback(req) + } + }, up) + if err != nil { + return data, err + } + switch utils.Json.Get(data, "zt").ToInt() { + case 1, 2, 4: + if resp != nil { + // 返回类型不统一,忽略错误 + utils.Json.Unmarshal(data, resp) + } + return data, nil + case 9: // 登录过期 + return data, ErrCookieExpiration + default: + info := utils.Json.Get(data, "inf").ToString() + if info == "" { + info = utils.Json.Get(data, "info").ToString() + } + return data, fmt.Errorf(info) + } +} + +func (d *LanZou) request(url string, method string, callback base.ReqCallback, up bool) ([]byte, error) { + var req *resty.Request + if up { + once.Do(func() { + upClient = base.NewRestyClient().SetTimeout(120 * time.Second) + }) + req = upClient.R() + } else { + req = base.RestyClient.R() + } + + req.SetHeaders(map[string]string{ + "Referer": "https://pc.woozooo.com", + "User-Agent": d.UserAgent, + }) + + if d.Cookie != "" { + req.SetHeader("cookie", d.Cookie) + } + + if callback != nil { + callback(req) + } + + res, err := req.Execute(method, url) + if err != nil { + return nil, err + } + log.Debugf("lanzou request: url=>%s ,stats=>%d ,body => %s\n", res.Request.URL, res.StatusCode(), res.String()) + return res.Body(), err +} + +func (d *LanZou) Login() ([]*http.Cookie, error) { + resp, err := base.NewRestyClient().SetRedirectPolicy(resty.NoRedirectPolicy()). 
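+		// redirects are disabled so the Set-Cookie headers from mlogin.php can be read off the raw response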
+ R().SetFormData(map[string]string{ + "task": "3", + "uid": d.Account, + "pwd": d.Password, + "setSessionId": "", + "setSig": "", + "setScene": "", + "setTocen": "", + "formhash": "", + }).Post("https://up.woozooo.com/mlogin.php") + if err != nil { + return nil, err + } + if utils.Json.Get(resp.Body(), "zt").ToInt() != 1 { + return nil, fmt.Errorf("login err: %s", resp.Body()) + } + d.Cookie = CookieToString(resp.Cookies()) + return resp.Cookies(), nil +} + +/* +通过cookie获取数据 +*/ + +// 获取文件和文件夹,获取到的文件大小、更改时间不可信 +func (d *LanZou) GetAllFiles(folderID string) ([]model.Obj, error) { + folders, err := d.GetFolders(folderID) + if err != nil { + return nil, err + } + files, err := d.GetFiles(folderID) + if err != nil { + return nil, err + } + return append( + utils.MustSliceConvert(folders, func(folder FileOrFolder) model.Obj { + return &folder + }), utils.MustSliceConvert(files, func(file FileOrFolder) model.Obj { + return &file + })..., + ), nil +} + +// 通过ID获取文件夹 +func (d *LanZou) GetFolders(folderID string) ([]FileOrFolder, error) { + var resp RespText[[]FileOrFolder] + _, err := d.doupload(func(req *resty.Request) { + req.SetFormData(map[string]string{ + "task": "47", + "folder_id": folderID, + }) + }, &resp) + if err != nil { + return nil, err + } + return resp.Text, nil +} + +// 通过ID获取文件 +func (d *LanZou) GetFiles(folderID string) ([]FileOrFolder, error) { + files := make([]FileOrFolder, 0) + for pg := 1; ; pg++ { + var resp RespText[[]FileOrFolder] + _, err := d.doupload(func(req *resty.Request) { + req.SetFormData(map[string]string{ + "task": "5", + "folder_id": folderID, + "pg": strconv.Itoa(pg), + }) + }, &resp) + if err != nil { + return nil, err + } + if len(resp.Text) == 0 { + break + } + files = append(files, resp.Text...) + } + return files, nil +} + +// 通过ID获取文件夹分享地址 +func (d *LanZou) getFolderShareUrlByID(fileID string) (*FileShare, error) { + var resp RespInfo[FileShare] + _, err := d.doupload(func(req *resty.Request) { + req.SetFormData(map[string]string{ + "task": "18", + "file_id": fileID, + }) + }, &resp) + if err != nil { + return nil, err + } + return &resp.Info, nil +} + +// 通过ID获取文件分享地址 +func (d *LanZou) getFileShareUrlByID(fileID string) (*FileShare, error) { + var resp RespInfo[FileShare] + _, err := d.doupload(func(req *resty.Request) { + req.SetFormData(map[string]string{ + "task": "22", + "file_id": fileID, + }) + }, &resp) + if err != nil { + return nil, err + } + return &resp.Info, nil +} + +/* +通过分享链接获取数据 +*/ + +// 判断类容 +var isFileReg = regexp.MustCompile(`class="fileinfo"|id="file"|文件描述`) +var isFolderReg = regexp.MustCompile(`id="infos"`) + +// 获取文件文件夹基础信息 + +// 获取文件名称 +var nameFindReg = regexp.MustCompile(`(.+?) - 蓝奏云|id="filenajax">(.+?)|var filename = '(.+?)';|
([^<>]+?)
`) + +// 获取文件大小 +var sizeFindReg = regexp.MustCompile(`(?i)大小\W*([0-9.]+\s*[bkm]+)`) + +// 获取文件时间 +var timeFindReg = regexp.MustCompile(`\d+\s*[秒天分小][钟时]?前|[昨前]天|\d{4}-\d{2}-\d{2}`) + +// 查找分享文件夹子文件夹ID和名称 +var findSubFolderReg = regexp.MustCompile(`(?i)(?:folderlink|mbxfolder).+href="/(.+?)"(?:.+filename")?>(.+?)<`) + +// 获取下载页面链接 +var findDownPageParamReg = regexp.MustCompile(` acw_sc__v2 validation error ,data => %s\n", firstPageDataStr) + return "", err + } + continue + } + return firstPageDataStr, nil + } + return "", errors.New("acw_sc__v2 validation error") +} + +// 通过分享链接获取文件或文件夹 +func (d *LanZou) GetFileOrFolderByShareUrl(shareID, pwd string) ([]model.Obj, error) { + pageData, err := d.getShareUrlHtml(shareID) + if err != nil { + return nil, err + } + + if !isFileReg.MatchString(pageData) { + files, err := d.getFolderByShareUrl(pwd, pageData) + if err != nil { + return nil, err + } + return utils.MustSliceConvert(files, func(file FileOrFolderByShareUrl) model.Obj { + return &file + }), nil + } else { + file, err := d.getFilesByShareUrl(shareID, pwd, pageData) + if err != nil { + return nil, err + } + return []model.Obj{file}, nil + } +} + +// 通过分享链接获取文件(下载链接也使用此方法) +// FileOrFolderByShareUrl 包含 pwd 和 url 字段 +// 参考 https://github.com/zaxtyson/LanZouCloud-API/blob/ab2e9ec715d1919bf432210fc16b91c6775fbb99/lanzou/api/core.py#L440 +func (d *LanZou) GetFilesByShareUrl(shareID, pwd string) (file *FileOrFolderByShareUrl, err error) { + pageData, err := d.getShareUrlHtml(shareID) + if err != nil { + return nil, err + } + return d.getFilesByShareUrl(shareID, pwd, pageData) +} + +func (d *LanZou) getFilesByShareUrl(shareID, pwd string, sharePageData string) (*FileOrFolderByShareUrl, error) { + var ( + param map[string]string + downloadUrl string + baseUrl string + file FileOrFolderByShareUrl + ) + + // 需要密码 + if strings.Contains(sharePageData, "pwdload") || strings.Contains(sharePageData, "passwddiv") { + sharePageData, err := getJSFunctionByName(sharePageData, "down_p") + if err != nil { + return nil, err + } + param, err := htmlJsonToMap(sharePageData) + if err != nil { + return nil, err + } + param["p"] = pwd + + fileIDs := findFileIDReg.FindStringSubmatch(sharePageData) + var fileID string + if len(fileIDs) > 1 { + fileID = fileIDs[1] + } else { + return nil, fmt.Errorf("not find file id") + } + var resp FileShareInfoAndUrlResp[string] + _, err = d.post(d.ShareUrl+"/ajaxm.php?file="+fileID, func(req *resty.Request) { req.SetFormData(param) }, &resp) + if err != nil { + return nil, err + } + file.NameAll = resp.Inf + file.Pwd = pwd + baseUrl = resp.GetBaseUrl() + downloadUrl = resp.GetDownloadUrl() + } else { + urlpaths := findDownPageParamReg.FindStringSubmatch(sharePageData) + if len(urlpaths) != 2 { + log.Errorf("lanzou: err => not find file page param ,data => %s\n", sharePageData) + return nil, fmt.Errorf("not find file page param") + } + data, err := d.get(fmt.Sprint(d.ShareUrl, urlpaths[1]), nil) + if err != nil { + return nil, err + } + nextPageData := RemoveNotes(string(data)) + param, err = htmlJsonToMap(nextPageData) + if err != nil { + return nil, err + } + + fileIDs := findFileIDReg.FindStringSubmatch(nextPageData) + var fileID string + if len(fileIDs) > 1 { + fileID = fileIDs[1] + } else { + return nil, fmt.Errorf("not find file id") + } + var resp FileShareInfoAndUrlResp[int] + _, err = d.post(d.ShareUrl+"/ajaxm.php?file="+fileID, func(req *resty.Request) { req.SetFormData(param) }, &resp) + if err != nil { + return nil, err + } + baseUrl = resp.GetBaseUrl() + downloadUrl = 
resp.GetDownloadUrl() + + names := nameFindReg.FindStringSubmatch(sharePageData) + if len(names) > 1 { + for _, name := range names[1:] { + if name != "" { + file.NameAll = name + break + } + } + } + } + + sizes := sizeFindReg.FindStringSubmatch(sharePageData) + if len(sizes) == 2 { + file.Size = sizes[1] + } + file.ID = shareID + file.Time = timeFindReg.FindString(sharePageData) + + // 重定向获取真实链接 + res, err := base.NoRedirectClient.R().SetHeaders(map[string]string{ + "accept-language": "zh-CN,zh;q=0.9,en;q=0.8,en-GB;q=0.7,en-US;q=0.6", + }).Get(downloadUrl) + if err != nil { + return nil, err + } + + file.Url = res.Header().Get("location") + + // 触发验证 + rPageData := res.String() + if res.StatusCode() != 302 { + param, err = htmlJsonToMap(rPageData) + if err != nil { + return nil, err + } + param["el"] = "2" + time.Sleep(time.Second * 2) + + // 通过验证获取直连 + data, err := d.post(fmt.Sprint(baseUrl, "/ajax.php"), func(req *resty.Request) { req.SetFormData(param) }, nil) + if err != nil { + return nil, err + } + file.Url = utils.Json.Get(data, "url").ToString() + } + return &file, nil +} + +// 通过分享链接获取文件夹 +// 似乎子目录和文件不会加密 +// 参考 https://github.com/zaxtyson/LanZouCloud-API/blob/ab2e9ec715d1919bf432210fc16b91c6775fbb99/lanzou/api/core.py#L1089 +func (d *LanZou) GetFolderByShareUrl(shareID, pwd string) ([]FileOrFolderByShareUrl, error) { + pageData, err := d.getShareUrlHtml(shareID) + if err != nil { + return nil, err + } + return d.getFolderByShareUrl(pwd, pageData) +} + +func (d *LanZou) getFolderByShareUrl(pwd string, sharePageData string) ([]FileOrFolderByShareUrl, error) { + from, err := htmlJsonToMap(sharePageData) + if err != nil { + return nil, err + } + + files := make([]FileOrFolderByShareUrl, 0) + // vip获取文件夹 + floders := findSubFolderReg.FindAllStringSubmatch(sharePageData, -1) + for _, floder := range floders { + if len(floder) == 3 { + files = append(files, FileOrFolderByShareUrl{ + // Pwd: pwd, // 子文件夹不加密 + ID: floder[1], + NameAll: floder[2], + IsFloder: true, + }) + } + } + + // 获取文件 + from["pwd"] = pwd + for page := 1; ; page++ { + from["pg"] = strconv.Itoa(page) + var resp FileOrFolderByShareUrlResp + _, err := d.post(d.ShareUrl+"/filemoreajax.php", func(req *resty.Request) { req.SetFormData(from) }, &resp) + if err != nil { + return nil, err + } + // 文件夹中的文件加密 + for i := 0; i < len(resp.Text); i++ { + resp.Text[i].Pwd = pwd + } + if len(resp.Text) == 0 { + break + } + files = append(files, resp.Text...) 
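+		// pause between filemoreajax.php pages, presumably to stay clear of lanzou's request-rate checks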
+ time.Sleep(time.Second) + } + return files, nil +} + +// 通过下载头获取真实文件信息 +func (d *LanZou) getFileRealInfo(downURL string) (*int64, *time.Time) { + res, _ := base.RestyClient.R().Head(downURL) + if res == nil { + return nil, nil + } + time, _ := http.ParseTime(res.Header().Get("Last-Modified")) + size, _ := strconv.ParseInt(res.Header().Get("Content-Length"), 10, 64) + return &size, &time +} + +func (d *LanZou) getVeiAndUid() (vei string, uid string, err error) { + var resp []byte + resp, err = d.get("https://pc.woozooo.com/mydisk.php", func(req *resty.Request) { + req.SetQueryParams(map[string]string{ + "item": "files", + "action": "index", + }) + }) + if err != nil { + return + } + // uid + uids := regexp.MustCompile(`uid=([^'"&;]+)`).FindStringSubmatch(string(resp)) + if len(uids) < 2 { + err = fmt.Errorf("uid variable not find") + return + } + uid = uids[1] + + // vei + html := RemoveNotes(string(resp)) + data, err := htmlJsonToMap(html) + if err != nil { + return + } + vei = data["vei"] + + return +} diff --git a/drivers/lark.go b/drivers/lark.go new file mode 100644 index 0000000000000000000000000000000000000000..d5070078651600be756f1cbf39ec4013adb984c8 --- /dev/null +++ b/drivers/lark.go @@ -0,0 +1,8 @@ +// +build linux darwin windows +// +build amd64 arm64 + +package drivers + +import ( + _ "github.com/alist-org/alist/v3/drivers/lark" +) diff --git a/drivers/lark/driver.go b/drivers/lark/driver.go new file mode 100644 index 0000000000000000000000000000000000000000..d267230044486f63b0f8cd1b5dc505b09b17a951 --- /dev/null +++ b/drivers/lark/driver.go @@ -0,0 +1,397 @@ +package lark + +import ( + "context" + "errors" + "fmt" + "io" + "net/http" + "strconv" + "strings" + "time" + + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/errs" + "github.com/alist-org/alist/v3/internal/model" + lark "github.com/larksuite/oapi-sdk-go/v3" + larkcore "github.com/larksuite/oapi-sdk-go/v3/core" + larkdrive "github.com/larksuite/oapi-sdk-go/v3/service/drive/v1" + "golang.org/x/time/rate" +) + +type Lark struct { + model.Storage + Addition + + client *lark.Client + rootFolderToken string +} + +func (c *Lark) Config() driver.Config { + return config +} + +func (c *Lark) GetAddition() driver.Additional { + return &c.Addition +} + +func (c *Lark) Init(ctx context.Context) error { + c.client = lark.NewClient(c.AppId, c.AppSecret, lark.WithTokenCache(newTokenCache())) + + paths := strings.Split(c.RootFolderPath, "/") + token := "" + + var ok bool + var file *larkdrive.File + for _, p := range paths { + if p == "" { + token = "" + continue + } + + resp, err := c.client.Drive.File.ListByIterator(ctx, larkdrive.NewListFileReqBuilder().FolderToken(token).Build()) + if err != nil { + return err + } + + for { + ok, file, err = resp.Next() + if !ok { + return errs.ObjectNotFound + } + + if err != nil { + return err + } + + if *file.Type == "folder" && *file.Name == p { + token = *file.Token + break + } + } + } + + c.rootFolderToken = token + + return nil +} + +func (c *Lark) Drop(ctx context.Context) error { + return nil +} + +func (c *Lark) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) { + token, ok := c.getObjToken(ctx, dir.GetPath()) + if !ok { + return nil, errs.ObjectNotFound + } + + if token == emptyFolderToken { + return nil, nil + } + + resp, err := c.client.Drive.File.ListByIterator(ctx, larkdrive.NewListFileReqBuilder().FolderToken(token).Build()) + if err != nil { + return nil, err + } + + ok = false + var file *larkdrive.File 
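+	// resp is a paging iterator from the lark SDK; Next is driven until ok turns false, accumulating every page into res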
+ var res []model.Obj + + for { + ok, file, err = resp.Next() + if !ok { + break + } + + if err != nil { + return nil, err + } + + modifiedUnix, _ := strconv.ParseInt(*file.ModifiedTime, 10, 64) + createdUnix, _ := strconv.ParseInt(*file.CreatedTime, 10, 64) + + f := model.Object{ + ID: *file.Token, + Path: strings.Join([]string{c.RootFolderPath, dir.GetPath(), *file.Name}, "/"), + Name: *file.Name, + Size: 0, + Modified: time.Unix(modifiedUnix, 0), + Ctime: time.Unix(createdUnix, 0), + IsFolder: *file.Type == "folder", + } + res = append(res, &f) + } + + return res, nil +} + +func (c *Lark) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) { + token, ok := c.getObjToken(ctx, file.GetPath()) + if !ok { + return nil, errs.ObjectNotFound + } + + resp, err := c.client.GetTenantAccessTokenBySelfBuiltApp(ctx, &larkcore.SelfBuiltTenantAccessTokenReq{ + AppID: c.AppId, + AppSecret: c.AppSecret, + }) + + if err != nil { + return nil, err + } + + if !c.ExternalMode { + accessToken := resp.TenantAccessToken + + url := fmt.Sprintf("https://open.feishu.cn/open-apis/drive/v1/files/%s/download", token) + + req, err := http.NewRequest(http.MethodGet, url, nil) + if err != nil { + return nil, err + } + + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", accessToken)) + req.Header.Set("Range", "bytes=0-1") + + ar, err := http.DefaultClient.Do(req) + if err != nil { + return nil, err + } + + if ar.StatusCode != http.StatusPartialContent { + return nil, errors.New("failed to get download link") + } + + return &model.Link{ + URL: url, + Header: http.Header{ + "Authorization": []string{fmt.Sprintf("Bearer %s", accessToken)}, + }, + }, nil + } else { + url := strings.Join([]string{c.TenantUrlPrefix, "file", token}, "/") + + return &model.Link{ + URL: url, + }, nil + } +} + +func (c *Lark) MakeDir(ctx context.Context, parentDir model.Obj, dirName string) (model.Obj, error) { + token, ok := c.getObjToken(ctx, parentDir.GetPath()) + if !ok { + return nil, errs.ObjectNotFound + } + + body, err := larkdrive.NewCreateFolderFilePathReqBodyBuilder().FolderToken(token).Name(dirName).Build() + if err != nil { + return nil, err + } + + resp, err := c.client.Drive.File.CreateFolder(ctx, + larkdrive.NewCreateFolderFileReqBuilder().Body(body).Build()) + if err != nil { + return nil, err + } + + if !resp.Success() { + return nil, errors.New(resp.Error()) + } + + return &model.Object{ + ID: *resp.Data.Token, + Path: strings.Join([]string{c.RootFolderPath, parentDir.GetPath(), dirName}, "/"), + Name: dirName, + Size: 0, + IsFolder: true, + }, nil +} + +func (c *Lark) Move(ctx context.Context, srcObj, dstDir model.Obj) (model.Obj, error) { + srcToken, ok := c.getObjToken(ctx, srcObj.GetPath()) + if !ok { + return nil, errs.ObjectNotFound + } + + dstDirToken, ok := c.getObjToken(ctx, dstDir.GetPath()) + if !ok { + return nil, errs.ObjectNotFound + } + + req := larkdrive.NewMoveFileReqBuilder(). + Body(larkdrive.NewMoveFileReqBodyBuilder(). + Type("file"). + FolderToken(dstDirToken). + Build()).FileToken(srcToken). 
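+		// the move request addresses the source by file token and the destination by folder token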
+ Build() + + // 发起请求 + resp, err := c.client.Drive.File.Move(ctx, req) + if err != nil { + return nil, err + } + + if !resp.Success() { + return nil, errors.New(resp.Error()) + } + + return nil, nil +} + +func (c *Lark) Rename(ctx context.Context, srcObj model.Obj, newName string) (model.Obj, error) { + // TODO rename obj, optional + return nil, errs.NotImplement +} + +func (c *Lark) Copy(ctx context.Context, srcObj, dstDir model.Obj) (model.Obj, error) { + srcToken, ok := c.getObjToken(ctx, srcObj.GetPath()) + if !ok { + return nil, errs.ObjectNotFound + } + + dstDirToken, ok := c.getObjToken(ctx, dstDir.GetPath()) + if !ok { + return nil, errs.ObjectNotFound + } + + req := larkdrive.NewCopyFileReqBuilder(). + Body(larkdrive.NewCopyFileReqBodyBuilder(). + Name(srcObj.GetName()). + Type("file"). + FolderToken(dstDirToken). + Build()).FileToken(srcToken). + Build() + + // 发起请求 + resp, err := c.client.Drive.File.Copy(ctx, req) + if err != nil { + return nil, err + } + + if !resp.Success() { + return nil, errors.New(resp.Error()) + } + + return nil, nil +} + +func (c *Lark) Remove(ctx context.Context, obj model.Obj) error { + token, ok := c.getObjToken(ctx, obj.GetPath()) + if !ok { + return errs.ObjectNotFound + } + + req := larkdrive.NewDeleteFileReqBuilder(). + FileToken(token). + Type("file"). + Build() + + // 发起请求 + resp, err := c.client.Drive.File.Delete(ctx, req) + if err != nil { + return err + } + + if !resp.Success() { + return errors.New(resp.Error()) + } + + return nil +} + +var uploadLimit = rate.NewLimiter(rate.Every(time.Second), 5) + +func (c *Lark) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) (model.Obj, error) { + token, ok := c.getObjToken(ctx, dstDir.GetPath()) + if !ok { + return nil, errs.ObjectNotFound + } + + // prepare + req := larkdrive.NewUploadPrepareFileReqBuilder(). + FileUploadInfo(larkdrive.NewFileUploadInfoBuilder(). + FileName(stream.GetName()). + ParentType(`explorer`). + ParentNode(token). + Size(int(stream.GetSize())). + Build()). + Build() + + // 发起请求 + uploadLimit.Wait(ctx) + resp, err := c.client.Drive.File.UploadPrepare(ctx, req) + if err != nil { + return nil, err + } + + if !resp.Success() { + return nil, errors.New(resp.Error()) + } + + uploadId := *resp.Data.UploadId + blockSize := *resp.Data.BlockSize + blockCount := *resp.Data.BlockNum + + // upload + for i := 0; i < blockCount; i++ { + length := int64(blockSize) + if i == blockCount-1 { + length = stream.GetSize() - int64(i*blockSize) + } + + reader := io.LimitReader(stream, length) + + req := larkdrive.NewUploadPartFileReqBuilder(). + Body(larkdrive.NewUploadPartFileReqBodyBuilder(). + UploadId(uploadId). + Seq(i). + Size(int(length)). + File(reader). + Build()). + Build() + + // 发起请求 + uploadLimit.Wait(ctx) + resp, err := c.client.Drive.File.UploadPart(ctx, req) + + if err != nil { + return nil, err + } + + if !resp.Success() { + return nil, errors.New(resp.Error()) + } + + up(float64(i) / float64(blockCount)) + } + + //close + closeReq := larkdrive.NewUploadFinishFileReqBuilder(). + Body(larkdrive.NewUploadFinishFileReqBodyBuilder(). + UploadId(uploadId). + BlockNum(blockCount). + Build()). 
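+		// UploadFinish merges the blockCount parts uploaded above and yields the final file token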
+		Build()
+
+	// send the request
+	closeResp, err := c.client.Drive.File.UploadFinish(ctx, closeReq)
+	if err != nil {
+		return nil, err
+	}
+
+	if !closeResp.Success() {
+		return nil, errors.New(closeResp.Error())
+	}
+
+	return &model.Object{
+		ID: *closeResp.Data.FileToken,
+	}, nil
+}
+
+//func (d *Lark) Other(ctx context.Context, args model.OtherArgs) (interface{}, error) {
+//	return nil, errs.NotSupport
+//}
+
+var _ driver.Driver = (*Lark)(nil)
diff --git a/drivers/lark/meta.go b/drivers/lark/meta.go
new file mode 100644
index 0000000000000000000000000000000000000000..221345e222ca224b20051955479d8665e2886f50
--- /dev/null
+++ b/drivers/lark/meta.go
@@ -0,0 +1,36 @@
+package lark
+
+import (
+	"github.com/alist-org/alist/v3/internal/driver"
+	"github.com/alist-org/alist/v3/internal/op"
+)
+
+type Addition struct {
+	// Usually one of two
+	driver.RootPath
+	// define other
+	AppId           string `json:"app_id" type:"text" help:"app id"`
+	AppSecret       string `json:"app_secret" type:"text" help:"app secret"`
+	ExternalMode    bool   `json:"external_mode" type:"bool" help:"external mode"`
+	TenantUrlPrefix string `json:"tenant_url_prefix" type:"text" help:"tenant url prefix"`
+}
+
+var config = driver.Config{
+	Name:              "Lark",
+	LocalSort:         false,
+	OnlyLocal:         false,
+	OnlyProxy:         false,
+	NoCache:           false,
+	NoUpload:          false,
+	NeedMs:            false,
+	DefaultRoot:       "/",
+	CheckStatus:       false,
+	Alert:             "",
+	NoOverwriteUpload: true,
+}
+
+func init() {
+	op.RegisterDriver(func() driver.Driver {
+		return &Lark{}
+	})
+}
diff --git a/drivers/lark/types.go b/drivers/lark/types.go
new file mode 100644
index 0000000000000000000000000000000000000000..3ebefd556dc859b10cd472b4c33498fc6f1f5d8f
--- /dev/null
+++ b/drivers/lark/types.go
@@ -0,0 +1,32 @@
+package lark
+
+import (
+	"context"
+	"github.com/Xhofe/go-cache"
+	"time"
+)
+
+type TokenCache struct {
+	cache.ICache[string]
+}
+
+func (t *TokenCache) Set(_ context.Context, key string, value string, expireTime time.Duration) error {
+	t.ICache.Set(key, value, cache.WithEx[string](expireTime))
+
+	return nil
+}
+
+func (t *TokenCache) Get(_ context.Context, key string) (string, error) {
+	v, ok := t.ICache.Get(key)
+	if ok {
+		return v, nil
+	}
+
+	return "", nil
+}
+
+func newTokenCache() *TokenCache {
+	c := cache.NewMemCache[string]()
+
+	return &TokenCache{c}
+}
diff --git a/drivers/lark/util.go b/drivers/lark/util.go
new file mode 100644
index 0000000000000000000000000000000000000000..8c6828bd17656d18db5d3b43a4b5a35c8b6a7c82
--- /dev/null
+++ b/drivers/lark/util.go
@@ -0,0 +1,66 @@
+package lark
+
+import (
+	"context"
+	"github.com/Xhofe/go-cache"
+	larkdrive "github.com/larksuite/oapi-sdk-go/v3/service/drive/v1"
+	log "github.com/sirupsen/logrus"
+	"path"
+	"time"
+)
+
+const objTokenCacheDuration = 5 * time.Minute
+const emptyFolderToken = "empty"
+
+var objTokenCache = cache.NewMemCache[string]()
+var exOpts = cache.WithEx[string](objTokenCacheDuration)
+
+func (c *Lark) getObjToken(ctx context.Context, folderPath string) (string, bool) {
+	if token, ok := objTokenCache.Get(folderPath); ok {
+		return token, true
+	}
+
+	dir, name := path.Split(folderPath)
+	// strip the last slash of dir if it exists
+	if len(dir) > 0 && dir[len(dir)-1] == '/' {
+		dir = dir[:len(dir)-1]
+	}
+	if name == "" {
+		return c.rootFolderToken, true
+	}
+
+	var parentToken string
+	var found bool
+	parentToken, found = c.getObjToken(ctx, dir)
+	if !found {
+		return emptyFolderToken, false
+	}
+
+	req := larkdrive.NewListFileReqBuilder().FolderToken(parentToken).Build()
+	resp, err :=
c.client.Drive.File.ListByIterator(ctx, req) + + if err != nil { + log.WithError(err).Error("failed to list files") + return emptyFolderToken, false + } + + var file *larkdrive.File + for { + found, file, err = resp.Next() + if !found { + break + } + + if err != nil { + log.WithError(err).Error("failed to get next file") + break + } + + if *file.Name == name { + objTokenCache.Set(folderPath, *file.Token, exOpts) + return *file.Token, true + } + } + + return emptyFolderToken, false +} diff --git a/drivers/lenovonas_share/driver.go b/drivers/lenovonas_share/driver.go new file mode 100644 index 0000000000000000000000000000000000000000..12e8514325f95a7ca88bfa61df63722f47adab55 --- /dev/null +++ b/drivers/lenovonas_share/driver.go @@ -0,0 +1,121 @@ +package LenovoNasShare + +import ( + "context" + "net/http" + + "github.com/go-resty/resty/v2" + + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/errs" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/pkg/utils" +) + +type LenovoNasShare struct { + model.Storage + Addition + stoken string +} + +func (d *LenovoNasShare) Config() driver.Config { + return config +} + +func (d *LenovoNasShare) GetAddition() driver.Additional { + return &d.Addition +} + +func (d *LenovoNasShare) Init(ctx context.Context) error { + if d.Host == "" { + d.Host = "https://siot-share.lenovo.com.cn" + } + query := map[string]string{ + "code": d.ShareId, + "password": d.SharePwd, + } + resp, err := d.request(d.Host+"/oneproxy/api/share/v1/access", http.MethodGet, func(req *resty.Request) { + req.SetQueryParams(query) + }, nil) + if err != nil { + return err + } + d.stoken = utils.Json.Get(resp, "data", "stoken").ToString() + return nil +} + +func (d *LenovoNasShare) Drop(ctx context.Context) error { + return nil +} + +func (d *LenovoNasShare) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) { + files := make([]File, 0) + + var resp Files + query := map[string]string{ + "code": d.ShareId, + "num": "5000", + "stoken": d.stoken, + "path": dir.GetPath(), + } + _, err := d.request(d.Host+"/oneproxy/api/share/v1/files", http.MethodGet, func(req *resty.Request) { + req.SetQueryParams(query) + }, &resp) + if err != nil { + return nil, err + } + files = append(files, resp.Data.List...) 
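+	// note: File (types.go) already implements model.Obj, so the conversion
+	// below just re-types each entry; a single request with num=5000 is made
+	// and Files.Data.HasMore is not consulted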
+ + return utils.SliceConvert(files, func(src File) (model.Obj, error) { + return src, nil + }) +} + +func (d *LenovoNasShare) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) { + query := map[string]string{ + "code": d.ShareId, + "stoken": d.stoken, + "path": file.GetPath(), + } + resp, err := d.request(d.Host+"/oneproxy/api/share/v1/file/link", http.MethodGet, func(req *resty.Request) { + req.SetQueryParams(query) + }, nil) + if err != nil { + return nil, err + } + downloadUrl := d.Host + "/oneproxy/api/share/v1/file/download?code=" + d.ShareId + "&dtoken=" + utils.Json.Get(resp, "data", "param", "dtoken").ToString() + + link := model.Link{ + URL: downloadUrl, + Header: http.Header{ + "Referer": []string{"https://siot-share.lenovo.com.cn"}, + }, + } + return &link, nil +} + +func (d *LenovoNasShare) MakeDir(ctx context.Context, parentDir model.Obj, dirName string) (model.Obj, error) { + return nil, errs.NotImplement +} + +func (d *LenovoNasShare) Move(ctx context.Context, srcObj, dstDir model.Obj) (model.Obj, error) { + return nil, errs.NotImplement +} + +func (d *LenovoNasShare) Rename(ctx context.Context, srcObj model.Obj, newName string) (model.Obj, error) { + return nil, errs.NotImplement +} + +func (d *LenovoNasShare) Copy(ctx context.Context, srcObj, dstDir model.Obj) (model.Obj, error) { + return nil, errs.NotImplement +} + +func (d *LenovoNasShare) Remove(ctx context.Context, obj model.Obj) error { + return errs.NotImplement +} + +func (d *LenovoNasShare) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) (model.Obj, error) { + return nil, errs.NotImplement +} + +var _ driver.Driver = (*LenovoNasShare)(nil) diff --git a/drivers/lenovonas_share/meta.go b/drivers/lenovonas_share/meta.go new file mode 100644 index 0000000000000000000000000000000000000000..0bf80555739995dd7715dc85db61dea6e21eb89b --- /dev/null +++ b/drivers/lenovonas_share/meta.go @@ -0,0 +1,33 @@ +package LenovoNasShare + +import ( + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/op" +) + +type Addition struct { + driver.RootPath + ShareId string `json:"share_id" required:"true" help:"The part after the last / in the shared link"` + SharePwd string `json:"share_pwd" required:"true" help:"The password of the shared link"` + Host string `json:"host" required:"true" default:"https://siot-share.lenovo.com.cn" help:"You can change it to your local area network"` +} + +var config = driver.Config{ + Name: "LenovoNasShare", + LocalSort: true, + OnlyLocal: false, + OnlyProxy: false, + NoCache: false, + NoUpload: true, + NeedMs: false, + DefaultRoot: "", + CheckStatus: false, + Alert: "", + NoOverwriteUpload: false, +} + +func init() { + op.RegisterDriver(func() driver.Driver { + return &LenovoNasShare{} + }) +} diff --git a/drivers/lenovonas_share/types.go b/drivers/lenovonas_share/types.go new file mode 100644 index 0000000000000000000000000000000000000000..77b966d3bee88609681801f0d9e4832de9ef3302 --- /dev/null +++ b/drivers/lenovonas_share/types.go @@ -0,0 +1,82 @@ +package LenovoNasShare + +import ( + "encoding/json" + "time" + + "github.com/alist-org/alist/v3/pkg/utils" + + _ "github.com/alist-org/alist/v3/internal/model" +) + +func (f *File) UnmarshalJSON(data []byte) error { + type Alias File + aux := &struct { + CreateAt int64 `json:"time"` + UpdateAt int64 `json:"chtime"` + *Alias + }{ + Alias: (*Alias)(f), + } + + if err := json.Unmarshal(data, aux); err != nil { + return err + } + + f.CreateAt 
= time.Unix(aux.CreateAt, 0) + f.UpdateAt = time.Unix(aux.UpdateAt, 0) + + return nil +} + +type File struct { + FileName string `json:"name"` + Size int64 `json:"size"` + CreateAt time.Time `json:"time"` + UpdateAt time.Time `json:"chtime"` + Path string `json:"path"` + Type string `json:"type"` +} + +func (f File) GetHash() utils.HashInfo { + return utils.HashInfo{} +} + +func (f File) GetPath() string { + return f.Path +} + +func (f File) GetSize() int64 { + return f.Size +} + +func (f File) GetName() string { + return f.FileName +} + +func (f File) ModTime() time.Time { + return f.UpdateAt +} + +func (f File) CreateTime() time.Time { + return f.CreateAt +} + +func (f File) IsDir() bool { + return f.Type == "dir" +} + +func (f File) GetID() string { + return f.GetPath() +} + +func (f File) Thumb() string { + return "" +} + +type Files struct { + Data struct { + List []File `json:"list"` + HasMore bool `json:"has_more"` + } `json:"data"` +} diff --git a/drivers/lenovonas_share/util.go b/drivers/lenovonas_share/util.go new file mode 100644 index 0000000000000000000000000000000000000000..ccf3af042a48a1fb423ed861302d09af7399030b --- /dev/null +++ b/drivers/lenovonas_share/util.go @@ -0,0 +1,36 @@ +package LenovoNasShare + +import ( + "errors" + + "github.com/alist-org/alist/v3/drivers/base" + "github.com/alist-org/alist/v3/pkg/utils" + jsoniter "github.com/json-iterator/go" +) + +func (d *LenovoNasShare) request(url string, method string, callback base.ReqCallback, resp interface{}) ([]byte, error) { + req := base.RestyClient.R() + req.SetHeaders(map[string]string{ + "origin": "https://siot-share.lenovo.com.cn", + "referer": "https://siot-share.lenovo.com.cn/", + "user-agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) alist-client", + "platform": "web", + "app-version": "3", + }) + if callback != nil { + callback(req) + } + if resp != nil { + req.SetResult(resp) + } + res, err := req.Execute(method, url) + if err != nil { + return nil, err + } + body := res.Body() + result := utils.Json.Get(body, "result").ToBool() + if !result { + return nil, errors.New(jsoniter.Get(body, "error", "msg").ToString()) + } + return body, nil +} diff --git a/drivers/local/driver.go b/drivers/local/driver.go new file mode 100644 index 0000000000000000000000000000000000000000..2519232e7d6e9d18065ea5d321c306b4f0bb23cb --- /dev/null +++ b/drivers/local/driver.go @@ -0,0 +1,331 @@ +package local + +import ( + "bytes" + "context" + "errors" + "fmt" + "io/fs" + "net/http" + "os" + stdpath "path" + "path/filepath" + "strconv" + "strings" + "time" + + "github.com/alist-org/alist/v3/internal/conf" + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/errs" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/internal/sign" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/alist-org/alist/v3/server/common" + "github.com/alist-org/times" + cp "github.com/otiai10/copy" + log "github.com/sirupsen/logrus" + _ "golang.org/x/image/webp" +) + +type Local struct { + model.Storage + Addition + mkdirPerm int32 + + // zero means no limit + thumbConcurrency int + thumbTokenBucket TokenBucket +} + +func (d *Local) Config() driver.Config { + return config +} + +func (d *Local) Init(ctx context.Context) error { + if d.MkdirPerm == "" { + d.mkdirPerm = 0777 + } else { + v, err := strconv.ParseUint(d.MkdirPerm, 8, 32) + if err != nil { + return err + } + d.mkdirPerm = int32(v) + } + if !utils.Exists(d.GetRootPath()) { + return fmt.Errorf("root folder %s not 
found", d.GetRootPath())
+	}
+	if !filepath.IsAbs(d.GetRootPath()) {
+		abs, err := filepath.Abs(d.GetRootPath())
+		if err != nil {
+			return err
+		}
+		d.Addition.RootFolderPath = abs
+	}
+	if d.ThumbCacheFolder != "" && !utils.Exists(d.ThumbCacheFolder) {
+		err := os.MkdirAll(d.ThumbCacheFolder, os.FileMode(d.mkdirPerm))
+		if err != nil {
+			return err
+		}
+	}
+	if d.ThumbConcurrency != "" {
+		v, err := strconv.ParseUint(d.ThumbConcurrency, 10, 32)
+		if err != nil {
+			return err
+		}
+		d.thumbConcurrency = int(v)
+	}
+	if d.thumbConcurrency == 0 {
+		d.thumbTokenBucket = NewNopTokenBucket()
+	} else {
+		d.thumbTokenBucket = NewStaticTokenBucketWithMigration(d.thumbTokenBucket, d.thumbConcurrency)
+	}
+	return nil
+}
+
+func (d *Local) Drop(ctx context.Context) error {
+	return nil
+}
+
+func (d *Local) GetAddition() driver.Additional {
+	return &d.Addition
+}
+
+func (d *Local) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) {
+	fullPath := dir.GetPath()
+	rawFiles, err := readDir(fullPath)
+	if err != nil {
+		return nil, err
+	}
+	var files []model.Obj
+	for _, f := range rawFiles {
+		if !d.ShowHidden && strings.HasPrefix(f.Name(), ".") {
+			continue
+		}
+		file := d.FileInfoToObj(ctx, f, args.ReqPath, fullPath)
+		files = append(files, file)
+	}
+	return files, nil
+}
+func (d *Local) FileInfoToObj(ctx context.Context, f fs.FileInfo, reqPath string, fullPath string) model.Obj {
+	thumb := ""
+	if d.Thumbnail {
+		typeName := utils.GetFileType(f.Name())
+		if typeName == conf.IMAGE || typeName == conf.VIDEO {
+			thumb = common.GetApiUrl(common.GetHttpReq(ctx)) + stdpath.Join("/d", reqPath, f.Name())
+			thumb = utils.EncodePath(thumb, true)
+			thumb += "?type=thumb&sign=" + sign.Sign(stdpath.Join(reqPath, f.Name()))
+		}
+	}
+	isFolder := f.IsDir() || isSymlinkDir(f, fullPath)
+	var size int64
+	if !isFolder {
+		size = f.Size()
+	}
+	var ctime time.Time
+	t, err := times.Stat(stdpath.Join(fullPath, f.Name()))
+	if err == nil {
+		if t.HasBirthTime() {
+			ctime = t.BirthTime()
+		}
+	}
+
+	file := model.ObjThumb{
+		Object: model.Object{
+			Path:     filepath.Join(fullPath, f.Name()),
+			Name:     f.Name(),
+			Modified: f.ModTime(),
+			Size:     size,
+			IsFolder: isFolder,
+			Ctime:    ctime,
+		},
+		Thumbnail: model.Thumbnail{
+			Thumbnail: thumb,
+		},
+	}
+	return &file
+}
+func (d *Local) GetMeta(ctx context.Context, path string) (model.Obj, error) {
+	f, err := os.Stat(path)
+	if err != nil {
+		return nil, err
+	}
+	file := d.FileInfoToObj(ctx, f, path, path)
+	//h := "123123"
+	//if s, ok := f.(model.SetHash); ok && file.GetHash() == ("","") {
+	//	s.SetHash(h,"SHA1")
+	//}
+	return file, nil
+
+}
+
+func (d *Local) Get(ctx context.Context, path string) (model.Obj, error) {
+	path = filepath.Join(d.GetRootPath(), path)
+	f, err := os.Stat(path)
+	if err != nil {
+		if strings.Contains(err.Error(), "cannot find the file") {
+			return nil, errs.ObjectNotFound
+		}
+		return nil, err
+	}
+	isFolder := f.IsDir() || isSymlinkDir(f, path)
+	size := f.Size()
+	if isFolder {
+		size = 0
+	}
+	var ctime time.Time
+	t, err := times.Stat(path)
+	if err == nil {
+		if t.HasBirthTime() {
+			ctime = t.BirthTime()
+		}
+	}
+	file := model.Object{
+		Path:     path,
+		Name:     f.Name(),
+		Modified: f.ModTime(),
+		Ctime:    ctime,
+		Size:     size,
+		IsFolder: isFolder,
+	}
+	return &file, nil
+}
+
+func (d *Local) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) {
+	fullPath := file.GetPath()
+	var link model.Link
+	if args.Type == "thumb" && utils.Ext(file.GetName()) != "svg" {
+		var buf
*bytes.Buffer + var thumbPath *string + err := d.thumbTokenBucket.Do(ctx, func() error { + var err error + buf, thumbPath, err = d.getThumb(file) + return err + }) + if err != nil { + return nil, err + } + link.Header = http.Header{ + "Content-Type": []string{"image/png"}, + } + if thumbPath != nil { + open, err := os.Open(*thumbPath) + if err != nil { + return nil, err + } + link.MFile = open + } else { + link.MFile = model.NewNopMFile(bytes.NewReader(buf.Bytes())) + //link.Header.Set("Content-Length", strconv.Itoa(buf.Len())) + } + } else { + open, err := os.Open(fullPath) + if err != nil { + return nil, err + } + link.MFile = open + } + return &link, nil +} + +func (d *Local) MakeDir(ctx context.Context, parentDir model.Obj, dirName string) error { + fullPath := filepath.Join(parentDir.GetPath(), dirName) + err := os.MkdirAll(fullPath, os.FileMode(d.mkdirPerm)) + if err != nil { + return err + } + return nil +} + +func (d *Local) Move(ctx context.Context, srcObj, dstDir model.Obj) error { + srcPath := srcObj.GetPath() + dstPath := filepath.Join(dstDir.GetPath(), srcObj.GetName()) + if utils.IsSubPath(srcPath, dstPath) { + return fmt.Errorf("the destination folder is a subfolder of the source folder") + } + if err := os.Rename(srcPath, dstPath); err != nil && strings.Contains(err.Error(), "invalid cross-device link") { + // Handle cross-device file move in local driver + if err = d.Copy(ctx, srcObj, dstDir); err != nil { + return err + } else { + // Directly remove file without check recycle bin if successfully copied + if srcObj.IsDir() { + err = os.RemoveAll(srcObj.GetPath()) + } else { + err = os.Remove(srcObj.GetPath()) + } + return err + } + } else { + return err + } +} + +func (d *Local) Rename(ctx context.Context, srcObj model.Obj, newName string) error { + srcPath := srcObj.GetPath() + dstPath := filepath.Join(filepath.Dir(srcPath), newName) + err := os.Rename(srcPath, dstPath) + if err != nil { + return err + } + return nil +} + +func (d *Local) Copy(_ context.Context, srcObj, dstDir model.Obj) error { + srcPath := srcObj.GetPath() + dstPath := filepath.Join(dstDir.GetPath(), srcObj.GetName()) + if utils.IsSubPath(srcPath, dstPath) { + return fmt.Errorf("the destination folder is a subfolder of the source folder") + } + // Copy using otiai10/copy to perform more secure & efficient copy + return cp.Copy(srcPath, dstPath, cp.Options{ + Sync: true, // Sync file to disk after copy, may have performance penalty in filesystem such as ZFS + PreserveTimes: true, + PreserveOwner: true, + }) +} + +func (d *Local) Remove(ctx context.Context, obj model.Obj) error { + var err error + if utils.SliceContains([]string{"", "delete permanently"}, d.RecycleBinPath) { + if obj.IsDir() { + err = os.RemoveAll(obj.GetPath()) + } else { + err = os.Remove(obj.GetPath()) + } + } else { + dstPath := filepath.Join(d.RecycleBinPath, obj.GetName()) + if utils.Exists(dstPath) { + dstPath = filepath.Join(d.RecycleBinPath, obj.GetName()+"_"+time.Now().Format("20060102150405")) + } + err = os.Rename(obj.GetPath(), dstPath) + } + if err != nil { + return err + } + return nil +} + +func (d *Local) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error { + fullPath := filepath.Join(dstDir.GetPath(), stream.GetName()) + out, err := os.Create(fullPath) + if err != nil { + return err + } + defer func() { + _ = out.Close() + if errors.Is(err, context.Canceled) { + _ = os.Remove(fullPath) + } + }() + err = utils.CopyWithCtx(ctx, out, stream, stream.GetSize(), up) + if err 
!= nil {
+		return err
+	}
+	err = os.Chtimes(fullPath, stream.ModTime(), stream.ModTime())
+	if err != nil {
+		log.Errorf("[local] failed to change time of %s: %s", fullPath, err)
+	}
+	return nil
+}
+
+var _ driver.Driver = (*Local)(nil)
diff --git a/drivers/local/meta.go b/drivers/local/meta.go
new file mode 100644
index 0000000000000000000000000000000000000000..5ffac92023420d9127f8ac9d6b9f636945d4269a
--- /dev/null
+++ b/drivers/local/meta.go
@@ -0,0 +1,30 @@
+package local
+
+import (
+	"github.com/alist-org/alist/v3/internal/driver"
+	"github.com/alist-org/alist/v3/internal/op"
+)
+
+type Addition struct {
+	driver.RootPath
+	Thumbnail        bool   `json:"thumbnail" required:"true" help:"enable thumbnail"`
+	ThumbCacheFolder string `json:"thumb_cache_folder"`
+	ThumbConcurrency string `json:"thumb_concurrency" default:"16" required:"false" help:"Number of concurrent thumbnail generation goroutines. This controls how many thumbnails can be generated in parallel."`
+	ShowHidden       bool   `json:"show_hidden" default:"true" required:"false" help:"show hidden directories and files"`
+	MkdirPerm        string `json:"mkdir_perm" default:"777"`
+	RecycleBinPath   string `json:"recycle_bin_path" default:"delete permanently" help:"path to recycle bin, delete permanently if empty or keep 'delete permanently'"`
+}
+
+var config = driver.Config{
+	Name:        "Local",
+	OnlyLocal:   true,
+	LocalSort:   true,
+	NoCache:     true,
+	DefaultRoot: "/",
+}
+
+func init() {
+	op.RegisterDriver(func() driver.Driver {
+		return &Local{}
+	})
+}
diff --git a/drivers/local/token_bucket.go b/drivers/local/token_bucket.go
new file mode 100644
index 0000000000000000000000000000000000000000..23c6ebd63b79a72efcb2982b3b2526df990fa11b
--- /dev/null
+++ b/drivers/local/token_bucket.go
@@ -0,0 +1,95 @@
+package local
+
+import "context"
+
+type TokenBucket interface {
+	Take() <-chan struct{}
+	Put()
+	Do(context.Context, func() error) error
+}
+
+// StaticTokenBucket is a bucket with a fixed number of tokens,
+// where the retrieval and return of tokens are manually controlled.
+// In the initial state, the bucket is full.
+type StaticTokenBucket struct {
+	bucket chan struct{}
+}
+
+func NewStaticTokenBucket(size int) StaticTokenBucket {
+	bucket := make(chan struct{}, size)
+	for range size {
+		bucket <- struct{}{}
+	}
+	return StaticTokenBucket{bucket: bucket}
+}
+
+func NewStaticTokenBucketWithMigration(oldBucket TokenBucket, size int) StaticTokenBucket {
+	if oldBucket != nil {
+		oldStaticBucket, ok := oldBucket.(StaticTokenBucket)
+		if ok {
+			oldSize := cap(oldStaticBucket.bucket)
+			migrateSize := oldSize
+			if size < migrateSize {
+				migrateSize = size
+			}
+
+			bucket := make(chan struct{}, size)
+			for range size - migrateSize {
+				bucket <- struct{}{}
+			}
+
+			if migrateSize != 0 {
+				go func() {
+					for range migrateSize {
+						<-oldStaticBucket.bucket
+						bucket <- struct{}{}
+					}
+					close(oldStaticBucket.bucket)
+				}()
+			}
+			return StaticTokenBucket{bucket: bucket}
+		}
+	}
+	return NewStaticTokenBucket(size)
+}
+
+// The channel returned by Take may be closed when the local driver is modified.
+// Don't call Put after the channel has been closed.
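+//
+// A typical call site wraps the guarded work in Do, which waits for a token
+// (or for context cancellation) and returns the token afterwards. A usage
+// sketch, mirroring the thumbnail path in Link:
+//
+//	err := d.thumbTokenBucket.Do(ctx, func() error {
+//		buf, thumbPath, err = d.getThumb(file)
+//		return err
+//	})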
+func (b StaticTokenBucket) Take() <-chan struct{} {
+	return b.bucket
+}
+
+func (b StaticTokenBucket) Put() {
+	b.bucket <- struct{}{}
+}
+
+func (b StaticTokenBucket) Do(ctx context.Context, f func() error) error {
+	select {
+	case <-ctx.Done():
+		return ctx.Err()
+	case _, ok := <-b.Take():
+		if ok {
+			defer b.Put()
+		}
+	}
+	return f()
+}
+
+// NopTokenBucket: all function calls on this bucket will succeed immediately
+type NopTokenBucket struct {
+	nop chan struct{}
+}
+
+func NewNopTokenBucket() NopTokenBucket {
+	nop := make(chan struct{})
+	close(nop)
+	return NopTokenBucket{nop}
+}
+
+func (b NopTokenBucket) Take() <-chan struct{} {
+	return b.nop
+}
+
+func (b NopTokenBucket) Put() {}
+
+func (b NopTokenBucket) Do(_ context.Context, f func() error) error { return f() }
diff --git a/drivers/local/util.go b/drivers/local/util.go
new file mode 100644
index 0000000000000000000000000000000000000000..b994c2056b77f59c4ce80e908db4b3729dc85ba8
--- /dev/null
+++ b/drivers/local/util.go
@@ -0,0 +1,111 @@
+package local
+
+import (
+	"bytes"
+	"fmt"
+	"io/fs"
+	"os"
+	"path/filepath"
+	"sort"
+	"strings"
+
+	"github.com/alist-org/alist/v3/internal/conf"
+	"github.com/alist-org/alist/v3/internal/model"
+	"github.com/alist-org/alist/v3/pkg/utils"
+	"github.com/disintegration/imaging"
+	ffmpeg "github.com/u2takey/ffmpeg-go"
+)
+
+func isSymlinkDir(f fs.FileInfo, path string) bool {
+	if f.Mode()&os.ModeSymlink == os.ModeSymlink {
+		dst, err := os.Readlink(filepath.Join(path, f.Name()))
+		if err != nil {
+			return false
+		}
+		if !filepath.IsAbs(dst) {
+			dst = filepath.Join(path, dst)
+		}
+		stat, err := os.Stat(dst)
+		if err != nil {
+			return false
+		}
+		return stat.IsDir()
+	}
+	return false
+}
+
+func GetSnapshot(videoPath string, frameNum int) (imgData *bytes.Buffer, err error) {
+	srcBuf := bytes.NewBuffer(nil)
+	stream := ffmpeg.Input(videoPath).
+		Filter("select", ffmpeg.Args{fmt.Sprintf("gte(n,%d)", frameNum)}).
+		Output("pipe:", ffmpeg.KwArgs{"vframes": 1, "format": "image2", "vcodec": "mjpeg"}).
+		GlobalArgs("-loglevel", "error").Silent(true).
+ WithOutput(srcBuf, os.Stdout) + if err = stream.Run(); err != nil { + return nil, err + } + return srcBuf, nil +} + +func readDir(dirname string) ([]fs.FileInfo, error) { + f, err := os.Open(dirname) + if err != nil { + return nil, err + } + list, err := f.Readdir(-1) + f.Close() + if err != nil { + return nil, err + } + sort.Slice(list, func(i, j int) bool { return list[i].Name() < list[j].Name() }) + return list, nil +} + +func (d *Local) getThumb(file model.Obj) (*bytes.Buffer, *string, error) { + fullPath := file.GetPath() + thumbPrefix := "alist_thumb_" + thumbName := thumbPrefix + utils.GetMD5EncodeStr(fullPath) + ".png" + if d.ThumbCacheFolder != "" { + // skip if the file is a thumbnail + if strings.HasPrefix(file.GetName(), thumbPrefix) { + return nil, &fullPath, nil + } + thumbPath := filepath.Join(d.ThumbCacheFolder, thumbName) + if utils.Exists(thumbPath) { + return nil, &thumbPath, nil + } + } + var srcBuf *bytes.Buffer + if utils.GetFileType(file.GetName()) == conf.VIDEO { + videoBuf, err := GetSnapshot(fullPath, 10) + if err != nil { + return nil, nil, err + } + srcBuf = videoBuf + } else { + imgData, err := os.ReadFile(fullPath) + if err != nil { + return nil, nil, err + } + imgBuf := bytes.NewBuffer(imgData) + srcBuf = imgBuf + } + + image, err := imaging.Decode(srcBuf, imaging.AutoOrientation(true)) + if err != nil { + return nil, nil, err + } + thumbImg := imaging.Resize(image, 144, 0, imaging.Lanczos) + var buf bytes.Buffer + err = imaging.Encode(&buf, thumbImg, imaging.PNG) + if err != nil { + return nil, nil, err + } + if d.ThumbCacheFolder != "" { + err = os.WriteFile(filepath.Join(d.ThumbCacheFolder, thumbName), buf.Bytes(), 0666) + if err != nil { + return nil, nil, err + } + } + return &buf, nil, nil +} diff --git a/drivers/mediatrack/driver.go b/drivers/mediatrack/driver.go new file mode 100644 index 0000000000000000000000000000000000000000..f0f1ded00872cedbec9ca16616da25fa0f6647be --- /dev/null +++ b/drivers/mediatrack/driver.go @@ -0,0 +1,230 @@ +package mediatrack + +import ( + "context" + "crypto/md5" + "encoding/hex" + "fmt" + "io" + "net/http" + "strconv" + "time" + + "github.com/alist-org/alist/v3/drivers/base" + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/s3/s3manager" + "github.com/go-resty/resty/v2" + "github.com/google/uuid" + log "github.com/sirupsen/logrus" +) + +type MediaTrack struct { + model.Storage + Addition +} + +func (d *MediaTrack) Config() driver.Config { + return config +} + +func (d *MediaTrack) GetAddition() driver.Additional { + return &d.Addition +} + +func (d *MediaTrack) Init(ctx context.Context) error { + _, err := d.request("https://kayle.api.mediatrack.cn/users", http.MethodGet, nil, nil) + return err +} + +func (d *MediaTrack) Drop(ctx context.Context) error { + return nil +} + +func (d *MediaTrack) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) { + files, err := d.getFiles(dir.GetID()) + if err != nil { + return nil, err + } + return utils.SliceConvert(files, func(f File) (model.Obj, error) { + size, _ := strconv.ParseInt(f.Size, 10, 64) + thumb := "" + if f.File != nil && f.File.Cover != "" { + thumb = "https://nano.mtres.cn/" + f.File.Cover + } + return &Object{ + Object: model.Object{ + ID: f.ID, + Name: f.Title, + Modified: f.UpdatedAt, 
+ IsFolder: f.File == nil, + Size: size, + }, + Thumbnail: model.Thumbnail{Thumbnail: thumb}, + ParentID: dir.GetID(), + }, nil + }) +} + +func (d *MediaTrack) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) { + url := fmt.Sprintf("https://kayn.api.mediatrack.cn/v1/download_token/asset?asset_id=%s&source_type=project&password=&source_id=%s", + file.GetID(), d.ProjectID) + log.Debugf("media track url: %s", url) + body, err := d.request(url, http.MethodGet, nil, nil) + if err != nil { + return nil, err + } + token := utils.Json.Get(body, "data", "token").ToString() + url = "https://kayn.api.mediatrack.cn/v1/download/redirect?token=" + token + res, err := base.NoRedirectClient.R().Get(url) + if err != nil { + return nil, err + } + log.Debug(res.String()) + link := model.Link{ + URL: url, + } + log.Debugln("res code: ", res.StatusCode()) + if res.StatusCode() == 302 { + link.URL = res.Header().Get("location") + expired := time.Duration(60) * time.Second + link.Expiration = &expired + } + return &link, nil +} + +func (d *MediaTrack) MakeDir(ctx context.Context, parentDir model.Obj, dirName string) error { + url := fmt.Sprintf("https://jayce.api.mediatrack.cn/v3/assets/%s/children", parentDir.GetID()) + _, err := d.request(url, http.MethodPost, func(req *resty.Request) { + req.SetBody(base.Json{ + "type": 1, + "title": dirName, + }) + }, nil) + return err +} + +func (d *MediaTrack) Move(ctx context.Context, srcObj, dstDir model.Obj) error { + data := base.Json{ + "parent_id": dstDir.GetID(), + "ids": []string{srcObj.GetID()}, + } + url := "https://jayce.api.mediatrack.cn/v4/assets/batch/move" + _, err := d.request(url, http.MethodPost, func(req *resty.Request) { + req.SetBody(data) + }, nil) + return err +} + +func (d *MediaTrack) Rename(ctx context.Context, srcObj model.Obj, newName string) error { + url := "https://jayce.api.mediatrack.cn/v3/assets/" + srcObj.GetID() + data := base.Json{ + "title": newName, + } + _, err := d.request(url, http.MethodPut, func(req *resty.Request) { + req.SetBody(data) + }, nil) + return err +} + +func (d *MediaTrack) Copy(ctx context.Context, srcObj, dstDir model.Obj) error { + data := base.Json{ + "parent_id": dstDir.GetID(), + "ids": []string{srcObj.GetID()}, + } + url := "https://jayce.api.mediatrack.cn/v4/assets/batch/clone" + _, err := d.request(url, http.MethodPost, func(req *resty.Request) { + req.SetBody(data) + }, nil) + return err +} + +func (d *MediaTrack) Remove(ctx context.Context, obj model.Obj) error { + var parentID string + if o, ok := obj.(*Object); ok { + parentID = o.ParentID + } else { + return fmt.Errorf("obj is not local Object") + } + data := base.Json{ + "origin_id": parentID, + "ids": []string{obj.GetID()}, + } + url := "https://jayce.api.mediatrack.cn/v4/assets/batch/delete" + _, err := d.request(url, http.MethodDelete, func(req *resty.Request) { + req.SetBody(data) + }, nil) + return err +} + +func (d *MediaTrack) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error { + src := "assets/" + uuid.New().String() + var resp UploadResp + _, err := d.request("https://jayce.api.mediatrack.cn/v3/storage/tokens/asset", http.MethodGet, func(req *resty.Request) { + req.SetQueryParam("src", src) + }, &resp) + if err != nil { + return err + } + credential := resp.Data.Credentials + cfg := &aws.Config{ + Credentials: credentials.NewStaticCredentials(credential.TmpSecretID, credential.TmpSecretKey, credential.Token), + Region: &resp.Data.Region, + Endpoint: 
aws.String("cos.accelerate.myqcloud.com"), + } + s, err := session.NewSession(cfg) + if err != nil { + return err + } + tempFile, err := stream.CacheFullInTempFile() + if err != nil { + return err + } + defer func() { + _ = tempFile.Close() + }() + uploader := s3manager.NewUploader(s) + if stream.GetSize() > s3manager.MaxUploadParts*s3manager.DefaultUploadPartSize { + uploader.PartSize = stream.GetSize() / (s3manager.MaxUploadParts - 1) + } + input := &s3manager.UploadInput{ + Bucket: &resp.Data.Bucket, + Key: &resp.Data.Object, + Body: tempFile, + } + _, err = uploader.UploadWithContext(ctx, input) + if err != nil { + return err + } + url := fmt.Sprintf("https://jayce.api.mediatrack.cn/v3/assets/%s/children", dstDir.GetID()) + _, err = tempFile.Seek(0, io.SeekStart) + if err != nil { + return err + } + h := md5.New() + _, err = utils.CopyWithBuffer(h, tempFile) + if err != nil { + return err + } + hash := hex.EncodeToString(h.Sum(nil)) + data := base.Json{ + "category": 0, + "description": stream.GetName(), + "hash": hash, + "mime": stream.GetMimetype(), + "size": stream.GetSize(), + "src": src, + "title": stream.GetName(), + "type": 0, + } + _, err = d.request(url, http.MethodPost, func(req *resty.Request) { + req.SetBody(data) + }, nil) + return err +} + +var _ driver.Driver = (*MediaTrack)(nil) diff --git a/drivers/mediatrack/meta.go b/drivers/mediatrack/meta.go new file mode 100644 index 0000000000000000000000000000000000000000..47f112c357391785991ffdae3869afe2982870e0 --- /dev/null +++ b/drivers/mediatrack/meta.go @@ -0,0 +1,24 @@ +package mediatrack + +import ( + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/op" +) + +type Addition struct { + AccessToken string `json:"access_token" required:"true"` + ProjectID string `json:"project_id"` + driver.RootID + OrderBy string `json:"order_by" type:"select" options:"updated_at,title,size" default:"title"` + OrderDesc bool `json:"order_desc"` +} + +var config = driver.Config{ + Name: "MediaTrack", +} + +func init() { + op.RegisterDriver(func() driver.Driver { + return &MediaTrack{} + }) +} diff --git a/drivers/mediatrack/types.go b/drivers/mediatrack/types.go new file mode 100644 index 0000000000000000000000000000000000000000..e8805275ed8a884c630e74d40270dc6bf426d5a4 --- /dev/null +++ b/drivers/mediatrack/types.go @@ -0,0 +1,72 @@ +package mediatrack + +import ( + "time" + + "github.com/alist-org/alist/v3/internal/model" +) + +type BaseResp struct { + Status string `json:"status"` + Message string `json:"message"` +} +type File struct { + Category int `json:"category"` + ChildAssets []interface{} `json:"childAssets"` + CommentCount int `json:"comment_count"` + CoverAsset interface{} `json:"cover_asset"` + CoverAssetID string `json:"cover_asset_id"` + CreatedAt time.Time `json:"created_at"` + DeletedAt string `json:"deleted_at"` + Description string `json:"description"` + File *struct { + Cover string `json:"cover"` + Src string `json:"src"` + } `json:"file"` + //FileID string `json:"file_id"` + ID string `json:"id"` + + Size string `json:"size"` + Thumbnails []interface{} `json:"thumbnails"` + Title string `json:"title"` + UpdatedAt time.Time `json:"updated_at"` +} + +type ChildrenResp struct { + Status string `json:"status"` + Data struct { + Total int `json:"total"` + Assets []File `json:"assets"` + } `json:"data"` + Path string `json:"path"` + TraceID string `json:"trace_id"` + RequestID string `json:"requestId"` +} + +type UploadResp struct { + Status string `json:"status"` + Data struct { + 
Credentials struct {
+			TmpSecretID  string    `json:"TmpSecretId"`
+			TmpSecretKey string    `json:"TmpSecretKey"`
+			Token        string    `json:"Token"`
+			ExpiredTime  int       `json:"ExpiredTime"`
+			Expiration   time.Time `json:"Expiration"`
+			StartTime    int       `json:"StartTime"`
+		} `json:"credentials"`
+		Object string `json:"object"`
+		Bucket string `json:"bucket"`
+		Region string `json:"region"`
+		URL    string `json:"url"`
+		Size   string `json:"size"`
+	} `json:"data"`
+	Path      string `json:"path"`
+	TraceID   string `json:"trace_id"`
+	RequestID string `json:"requestId"`
+}
+
+type Object struct {
+	model.Object
+	model.Thumbnail
+	ParentID string
+}
diff --git a/drivers/mediatrack/util.go b/drivers/mediatrack/util.go
new file mode 100644
index 0000000000000000000000000000000000000000..37ca0b3d09c1ce4411e5b20c7f247bc9b0149b1b
--- /dev/null
+++ b/drivers/mediatrack/util.go
@@ -0,0 +1,69 @@
+package mediatrack
+
+import (
+	"errors"
+	"fmt"
+	"net/http"
+	"strconv"
+
+	"github.com/alist-org/alist/v3/drivers/base"
+	"github.com/alist-org/alist/v3/pkg/utils"
+	"github.com/go-resty/resty/v2"
+	log "github.com/sirupsen/logrus"
+)
+
+// do other things that are not defined in the Driver interface
+
+func (d *MediaTrack) request(url string, method string, callback base.ReqCallback, resp interface{}) ([]byte, error) {
+	req := base.RestyClient.R()
+	req.SetHeader("Authorization", "Bearer "+d.AccessToken)
+	if callback != nil {
+		callback(req)
+	}
+	var e BaseResp
+	req.SetResult(&e)
+	res, err := req.Execute(method, url)
+	if err != nil {
+		return nil, err
+	}
+	log.Debugln(res.String())
+	if e.Status != "SUCCESS" {
+		return nil, errors.New(e.Message)
+	}
+	if resp != nil {
+		err = utils.Json.Unmarshal(res.Body(), resp)
+	}
+	return res.Body(), err
+}
+
+func (d *MediaTrack) getFiles(parentId string) ([]File, error) {
+	files := make([]File, 0)
+	url := fmt.Sprintf("https://jayce.api.mediatrack.cn/v4/assets/%s/children", parentId)
+	sort := ""
+	if d.OrderBy != "" {
+		if d.OrderDesc {
+			sort = "-"
+		}
+		sort += d.OrderBy
+	}
+	page := 1
+	for {
+		var resp ChildrenResp
+		_, err := d.request(url, http.MethodGet, func(req *resty.Request) {
+			req.SetQueryParams(map[string]string{
+				"page": strconv.Itoa(page),
+				"size": "50",
+				"sort": sort,
+			})
+		}, &resp)
+		if err != nil {
+			return nil, err
+		}
+		if len(resp.Data.Assets) == 0 {
+			break
+		}
+		page++
+		files = append(files, resp.Data.Assets...)
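+		// a non-empty page has been appended; loop again until the API
+		// returns an empty asset list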
+
+	return files, nil
+}
diff --git a/drivers/mega/driver.go b/drivers/mega/driver.go
new file mode 100644
index 0000000000000000000000000000000000000000..162aeef37e0f8070c1361e9c7743e20116192e58
--- /dev/null
+++ b/drivers/mega/driver.go
@@ -0,0 +1,195 @@
+package mega
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"io"
+	"time"
+
+	"github.com/alist-org/alist/v3/pkg/http_range"
+	"github.com/pquerna/otp/totp"
+	"github.com/rclone/rclone/lib/readers"
+
+	"github.com/alist-org/alist/v3/internal/driver"
+	"github.com/alist-org/alist/v3/internal/errs"
+	"github.com/alist-org/alist/v3/internal/model"
+	"github.com/alist-org/alist/v3/pkg/utils"
+	log "github.com/sirupsen/logrus"
+	"github.com/t3rm1n4l/go-mega"
+)
+
+type Mega struct {
+	model.Storage
+	Addition
+	c *mega.Mega
+}
+
+func (d *Mega) Config() driver.Config {
+	return config
+}
+
+func (d *Mega) GetAddition() driver.Additional {
+	return &d.Addition
+}
+
+func (d *Mega) Init(ctx context.Context) error {
+	var twoFACode = d.TwoFACode
+	d.c = mega.New()
+	if d.TwoFASecret != "" {
+		code, err := totp.GenerateCode(d.TwoFASecret, time.Now())
+		if err != nil {
+			return fmt.Errorf("generate totp code failed: %w", err)
+		}
+		twoFACode = code
+	}
+	return d.c.MultiFactorLogin(d.Email, d.Password, twoFACode)
+}
+
+func (d *Mega) Drop(ctx context.Context) error {
+	return nil
+}
+
+func (d *Mega) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) {
+	if node, ok := dir.(*MegaNode); ok {
+		nodes, err := d.c.FS.GetChildren(node.n)
+		if err != nil {
+			return nil, err
+		}
+		res := make([]model.Obj, 0)
+		for i := range nodes {
+			n := nodes[i]
+			if n.GetType() == mega.FILE || n.GetType() == mega.FOLDER {
+				res = append(res, &MegaNode{n})
+			}
+		}
+		return res, nil
+	}
+	log.Errorf("can't convert: %+v", dir)
+	return nil, fmt.Errorf("unable to convert obj to mega node")
+}
+
+func (d *Mega) GetRoot(ctx context.Context) (model.Obj, error) {
+	n := d.c.FS.GetRoot()
+	log.Debugf("mega root: %+v", *n)
+	return &MegaNode{n}, nil
+}
+
+func (d *Mega) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) {
+	if node, ok := file.(*MegaNode); ok {
+
+		//down, err := d.c.NewDownload(n.Node)
+		//if err != nil {
+		//	return nil, fmt.Errorf("open download file failed: %w", err)
+		//}
+
+		size := file.GetSize()
+		var finalClosers utils.Closers
+		resultRangeReader := func(ctx context.Context, httpRange http_range.Range) (io.ReadCloser, error) {
+			length := httpRange.Length
+			if httpRange.Length >= 0 && httpRange.Start+httpRange.Length >= size {
+				length = -1
+			}
+			var down *mega.Download
+			err := utils.Retry(3, time.Second, func() (err error) {
+				down, err = d.c.NewDownload(node.n)
+				return err
+			})
+			if err != nil {
+				return nil, fmt.Errorf("open download file failed: %w", err)
+			}
+			oo := &openObject{
+				ctx:  ctx,
+				d:    down,
+				skip: httpRange.Start,
+			}
+			finalClosers.Add(oo)
+
+			return readers.NewLimitedReadCloser(oo, length), nil
+		}
+		resultRangeReadCloser := &model.RangeReadCloser{RangeReader: resultRangeReader, Closers: finalClosers}
+		resultLink := &model.Link{
+			RangeReadCloser: resultRangeReadCloser,
+		}
+		return resultLink, nil
+	}
+	return nil, fmt.Errorf("unable to convert obj to mega node")
+}
+
+func (d *Mega) MakeDir(ctx context.Context, parentDir model.Obj, dirName string) error {
+	if parentNode, ok := parentDir.(*MegaNode); ok {
+		_, err := d.c.CreateDir(dirName, parentNode.n)
+		return err
+	}
+	return fmt.Errorf("unable to convert obj to mega node")
+}
+
+func (d *Mega) Move(ctx context.Context, srcObj, dstDir model.Obj) error {
+	if srcNode, ok := srcObj.(*MegaNode); ok {
+		if dstNode, ok := dstDir.(*MegaNode); ok {
+			return d.c.Move(srcNode.n, dstNode.n)
+		}
+	}
+	return fmt.Errorf("unable to convert obj to mega node")
+}
+
+func (d *Mega) Rename(ctx context.Context, srcObj model.Obj, newName string) error {
+	if srcNode, ok := srcObj.(*MegaNode); ok {
+		return d.c.Rename(srcNode.n, newName)
+	}
+	return fmt.Errorf("unable to convert obj to mega node")
+}
+
+func (d *Mega) Copy(ctx context.Context, srcObj, dstDir model.Obj) error {
+	return errs.NotImplement
+}
+
+func (d *Mega) Remove(ctx context.Context, obj model.Obj) error {
+	if node, ok := obj.(*MegaNode); ok {
+		return d.c.Delete(node.n, false)
+	}
+	return fmt.Errorf("unable to convert obj to mega node")
+}
+
+func (d *Mega) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error {
+	if dstNode, ok := dstDir.(*MegaNode); ok {
+		u, err := d.c.NewUpload(dstNode.n, stream.GetName(), stream.GetSize())
+		if err != nil {
+			return err
+		}
+
+		for id := 0; id < u.Chunks(); id++ {
+			if utils.IsCanceled(ctx) {
+				return ctx.Err()
+			}
+			_, chkSize, err := u.ChunkLocation(id)
+			if err != nil {
+				return err
+			}
+			chunk := make([]byte, chkSize)
+			n, err := io.ReadFull(stream, chunk)
+			if err != nil && err != io.EOF {
+				return err
+			}
+			if n != len(chunk) {
+				return errors.New("chunk too short")
+			}
+
+			err = u.UploadChunk(id, chunk)
+			if err != nil {
+				return err
+			}
+			up(float64(id) * 100 / float64(u.Chunks()))
+		}
+
+		_, err = u.Finish()
+		return err
+	}
+	return fmt.Errorf("unable to convert obj to mega node")
+}
+
+//func (d *Mega) Other(ctx context.Context, args model.OtherArgs) (interface{}, error) {
+//	return nil, errs.NotSupport
+//}
+
+var _ driver.Driver = (*Mega)(nil)
diff --git a/drivers/mega/meta.go b/drivers/mega/meta.go
new file mode 100644
index 0000000000000000000000000000000000000000..d0758637bb3d2a854975329c1d4d8c0f3e3c08d5
--- /dev/null
+++ b/drivers/mega/meta.go
@@ -0,0 +1,28 @@
+package mega
+
+import (
+	"github.com/alist-org/alist/v3/internal/driver"
+	"github.com/alist-org/alist/v3/internal/op"
+)
+
+type Addition struct {
+	// Usually one of two
+	//driver.RootPath
+	//driver.RootID
+	Email       string `json:"email" required:"true"`
+	Password    string `json:"password" required:"true"`
+	TwoFACode   string `json:"two_fa_code" required:"false" help:"2FA 6-digit code, filling in the 2FA code alone will not support reloading the driver"`
+	TwoFASecret string `json:"two_fa_secret" required:"false" help:"2FA secret"`
+}
+
+var config = driver.Config{
+	Name:      "Mega_nz",
+	LocalSort: true,
+	OnlyLocal: true,
+}
+
+func init() {
+	op.RegisterDriver(func() driver.Driver {
+		return &Mega{}
+	})
+}
diff --git a/drivers/mega/types.go b/drivers/mega/types.go
new file mode 100644
index 0000000000000000000000000000000000000000..3046d449c3458ebad4b1c81a1cbcd0a6f91689a2
--- /dev/null
+++ b/drivers/mega/types.go
@@ -0,0 +1,48 @@
+package mega
+
+import (
+	"github.com/alist-org/alist/v3/pkg/utils"
+	"time"
+
+	"github.com/alist-org/alist/v3/internal/model"
+	"github.com/t3rm1n4l/go-mega"
+)
+
+type MegaNode struct {
+	n *mega.Node
+}
+
+func (m *MegaNode) GetSize() int64 {
+	return m.n.GetSize()
+}
+
+func (m *MegaNode) GetName() string {
+	return m.n.GetName()
+}
+
+func (m *MegaNode) CreateTime() time.Time {
+	return m.n.GetTimeStamp()
+}
+
+func (m *MegaNode) GetHash() utils.HashInfo {
+	// Mega uses MD5, but the original file hash can't be retrieved,
+	// because files are encrypted in the cloud
+	return
utils.HashInfo{}
+}
+
+func (m *MegaNode) ModTime() time.Time {
+	return m.n.GetTimeStamp()
+}
+
+func (m *MegaNode) IsDir() bool {
+	return m.n.GetType() == mega.FOLDER || m.n.GetType() == mega.ROOT
+}
+
+func (m *MegaNode) GetID() string {
+	return m.n.GetHash()
+}
+
+func (m *MegaNode) GetPath() string {
+	return ""
+}
+
+var _ model.Obj = (*MegaNode)(nil)
diff --git a/drivers/mega/util.go b/drivers/mega/util.go
new file mode 100644
index 0000000000000000000000000000000000000000..f5ad25444c6e40003911b47b20581daae9b931e3
--- /dev/null
+++ b/drivers/mega/util.go
@@ -0,0 +1,92 @@
+package mega
+
+import (
+	"context"
+	"fmt"
+	"github.com/alist-org/alist/v3/pkg/utils"
+	"github.com/t3rm1n4l/go-mega"
+	"io"
+	"sync"
+	"time"
+)
+
+// do other things that are not defined in the Driver interface
+// openObject represents a download in progress
+type openObject struct {
+	ctx    context.Context
+	mu     sync.Mutex
+	d      *mega.Download
+	id     int
+	skip   int64
+	chunk  []byte
+	closed bool
+}
+
+// get the next chunk
+func (oo *openObject) getChunk(ctx context.Context) (err error) {
+	if oo.id >= oo.d.Chunks() {
+		return io.EOF
+	}
+	var chunk []byte
+	err = utils.Retry(3, time.Second, func() (err error) {
+		chunk, err = oo.d.DownloadChunk(oo.id)
+		return err
+	})
+	if err != nil {
+		return err
+	}
+	oo.id++
+	oo.chunk = chunk
+	return nil
+}
+
+// Read reads up to len(p) bytes into p.
+func (oo *openObject) Read(p []byte) (n int, err error) {
+	oo.mu.Lock()
+	defer oo.mu.Unlock()
+	if oo.closed {
+		return 0, fmt.Errorf("read on closed file")
+	}
+	// Skip data at the start if requested
+	for oo.skip > 0 {
+		_, size, err := oo.d.ChunkLocation(oo.id)
+		if err != nil {
+			return 0, err
+		}
+		if oo.skip < int64(size) {
+			break
+		}
+		oo.id++
+		oo.skip -= int64(size)
+	}
+	if len(oo.chunk) == 0 {
+		err = oo.getChunk(oo.ctx)
+		if err != nil {
+			return 0, err
+		}
+		if oo.skip > 0 {
+			oo.chunk = oo.chunk[oo.skip:]
+			oo.skip = 0
+		}
+	}
+	n = copy(p, oo.chunk)
+	oo.chunk = oo.chunk[n:]
+	return n, nil
+}
+
+// Close closes the file - MAC errors are reported here
+func (oo *openObject) Close() (err error) {
+	oo.mu.Lock()
+	defer oo.mu.Unlock()
+	if oo.closed {
+		return nil
+	}
+	err = utils.Retry(3, 500*time.Millisecond, func() (err error) {
+		return oo.d.Finish()
+	})
+	if err != nil {
+		return fmt.Errorf("failed to finish download: %w", err)
+	}
+	oo.closed = true
+	return nil
+}
diff --git a/drivers/mopan/driver.go b/drivers/mopan/driver.go
new file mode 100644
index 0000000000000000000000000000000000000000..369ec83b64d0baf3f3296db3e5f496b5697fc84c
--- /dev/null
+++ b/drivers/mopan/driver.go
@@ -0,0 +1,367 @@
+package mopan
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"io"
+	"net/http"
+	"strconv"
+	"strings"
+	"time"
+
+	"github.com/alist-org/alist/v3/drivers/base"
+	"github.com/alist-org/alist/v3/internal/driver"
+	"github.com/alist-org/alist/v3/internal/model"
+	"github.com/alist-org/alist/v3/internal/op"
+	"github.com/alist-org/alist/v3/pkg/errgroup"
+	"github.com/alist-org/alist/v3/pkg/utils"
+	"github.com/avast/retry-go"
+	"github.com/foxxorcat/mopan-sdk-go"
+	log "github.com/sirupsen/logrus"
+)
+
+type MoPan struct {
+	model.Storage
+	Addition
+	client *mopan.MoClient
+
+	userID       string
+	uploadThread int
+}
+
+func (d *MoPan) Config() driver.Config {
+	return config
+}
+
+func (d *MoPan) GetAddition() driver.Additional {
+	return &d.Addition
+}
+
+func (d *MoPan) Init(ctx context.Context) error {
+	d.uploadThread, _ = strconv.Atoi(d.UploadThread)
+	if d.uploadThread < 1 || d.uploadThread > 32 {
+		d.uploadThread,
d.UploadThread = 3, "3" + } + + defer func() { d.SMSCode = "" }() + + login := func() (err error) { + var loginData *mopan.LoginResp + if d.SMSCode != "" { + loginData, err = d.client.LoginBySmsStep2(d.Phone, d.SMSCode) + } else { + loginData, err = d.client.Login(d.Phone, d.Password) + } + if err != nil { + return err + } + d.client.SetAuthorization(loginData.Token) + + info, err := d.client.GetUserInfo() + if err != nil { + return err + } + d.userID = info.UserID + log.Debugf("[mopan] Phone: %s UserCloudStorageRelations: %+v", d.Phone, loginData.UserCloudStorageRelations) + cloudCircleApp, _ := d.client.QueryAllCloudCircleApp() + log.Debugf("[mopan] Phone: %s CloudCircleApp: %+v", d.Phone, cloudCircleApp) + if d.RootFolderID == "" { + for _, userCloudStorage := range loginData.UserCloudStorageRelations { + if userCloudStorage.Path == "/文件" { + d.RootFolderID = userCloudStorage.FolderID + } + } + } + return nil + } + d.client = mopan.NewMoClientWithRestyClient(base.NewRestyClient()). + SetRestyClient(base.RestyClient). + SetOnAuthorizationExpired(func(_ error) error { + err := login() + if err != nil { + d.Status = err.Error() + op.MustSaveDriverStorage(d) + } + return err + }) + + var deviceInfo mopan.DeviceInfo + if strings.TrimSpace(d.DeviceInfo) != "" && utils.Json.UnmarshalFromString(d.DeviceInfo, &deviceInfo) == nil { + d.client.SetDeviceInfo(&deviceInfo) + } + d.DeviceInfo, _ = utils.Json.MarshalToString(d.client.GetDeviceInfo()) + + if strings.Contains(d.SMSCode, "send") { + if _, err := d.client.LoginBySms(d.Phone); err != nil { + return err + } + return errors.New("please enter the SMS code") + } + return login() +} + +func (d *MoPan) Drop(ctx context.Context) error { + d.client = nil + d.userID = "" + return nil +} + +func (d *MoPan) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) { + var files []model.Obj + for page := 1; ; page++ { + data, err := d.client.QueryFiles(dir.GetID(), page, mopan.WarpParamOption( + func(j mopan.Json) { + j["orderBy"] = d.OrderBy + j["descending"] = d.OrderDirection == "desc" + }, + mopan.ParamOptionShareFile(d.CloudID), + )) + if err != nil { + return nil, err + } + + if len(data.FileListAO.FileList)+len(data.FileListAO.FolderList) == 0 { + break + } + + log.Debugf("[mopan] Phone: %s folder: %+v", d.Phone, data.FileListAO.FolderList) + files = append(files, utils.MustSliceConvert(data.FileListAO.FolderList, folderToObj)...) + files = append(files, utils.MustSliceConvert(data.FileListAO.FileList, fileToObj)...) 
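+		// folders first, then files; both are converted to model.Obj and accumulated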
+	}
+	return files, nil
+}
+
+func (d *MoPan) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) {
+	data, err := d.client.GetFileDownloadUrl(file.GetID(), mopan.WarpParamOption(mopan.ParamOptionShareFile(d.CloudID)))
+	if err != nil {
+		return nil, err
+	}
+
+	// un-escape HTML entities in the URL and force https
+	data.DownloadUrl = strings.Replace(strings.ReplaceAll(data.DownloadUrl, "&amp;", "&"), "http://", "https://", 1)
+	res, err := base.NoRedirectClient.R().SetDoNotParseResponse(true).SetContext(ctx).Get(data.DownloadUrl)
+	if err != nil {
+		return nil, err
+	}
+	defer func() {
+		_ = res.RawBody().Close()
+	}()
+	if res.StatusCode() == 302 {
+		data.DownloadUrl = res.Header().Get("location")
+	}
+
+	return &model.Link{
+		URL: data.DownloadUrl,
+	}, nil
+}
+
+func (d *MoPan) MakeDir(ctx context.Context, parentDir model.Obj, dirName string) (model.Obj, error) {
+	f, err := d.client.CreateFolder(dirName, parentDir.GetID(), mopan.WarpParamOption(
+		mopan.ParamOptionShareFile(d.CloudID),
+	))
+	if err != nil {
+		return nil, err
+	}
+	return folderToObj(*f), nil
+}
+
+func (d *MoPan) Move(ctx context.Context, srcObj, dstDir model.Obj) (model.Obj, error) {
+	return d.newTask(srcObj, dstDir, mopan.TASK_MOVE)
+}
+
+func (d *MoPan) Rename(ctx context.Context, srcObj model.Obj, newName string) (model.Obj, error) {
+	if srcObj.IsDir() {
+		_, err := d.client.RenameFolder(srcObj.GetID(), newName, mopan.WarpParamOption(
+			mopan.ParamOptionShareFile(d.CloudID),
+		))
+		if err != nil {
+			return nil, err
+		}
+	} else {
+		_, err := d.client.RenameFile(srcObj.GetID(), newName, mopan.WarpParamOption(
+			mopan.ParamOptionShareFile(d.CloudID),
+		))
+		if err != nil {
+			return nil, err
+		}
+	}
+	return CloneObj(srcObj, srcObj.GetID(), newName), nil
+}
+
+func (d *MoPan) Copy(ctx context.Context, srcObj, dstDir model.Obj) (model.Obj, error) {
+	return d.newTask(srcObj, dstDir, mopan.TASK_COPY)
+}
+
+func (d *MoPan) newTask(srcObj, dstDir model.Obj, taskType mopan.TaskType) (model.Obj, error) {
+	param := mopan.TaskParam{
+		UserOrCloudID:       d.userID,
+		Source:              1,
+		TaskType:            taskType,
+		TargetSource:        1,
+		TargetUserOrCloudID: d.userID,
+		TargetType:          1,
+		TargetFolderID:      dstDir.GetID(),
+		TaskStatusDetailDTOList: []mopan.TaskFileParam{
+			{
+				FileID:   srcObj.GetID(),
+				IsFolder: srcObj.IsDir(),
+				FileName: srcObj.GetName(),
+			},
+		},
+	}
+	if d.CloudID != "" {
+		param.UserOrCloudID = d.CloudID
+		param.Source = 2
+		param.TargetSource = 2
+		param.TargetUserOrCloudID = d.CloudID
+	}
+
+	task, err := d.client.AddBatchTask(param)
+	if err != nil {
+		return nil, err
+	}
+
+	for count := 0; count < 5; count++ {
+		stat, err := d.client.CheckBatchTask(mopan.TaskCheckParam{
+			TaskId:              task.TaskIDList[0],
+			TaskType:            task.TaskType,
+			TargetType:          1,
+			TargetFolderID:      task.TargetFolderID,
+			TargetSource:        param.TargetSource,
+			TargetUserOrCloudID: param.TargetUserOrCloudID,
+		})
+		if err != nil {
+			return nil, err
+		}
+
+		switch stat.TaskStatus {
+		case 2:
+			if err := d.client.CancelBatchTask(stat.TaskID, task.TaskType); err != nil {
+				return nil, err
+			}
+			return nil, errors.New("file name conflict")
+		case 4:
+			if task.TaskType == mopan.TASK_MOVE {
+				return CloneObj(srcObj, srcObj.GetID(), srcObj.GetName()), nil
+			}
+			return CloneObj(srcObj, stat.SuccessedFileIDList[0], srcObj.GetName()), nil
+		}
+		time.Sleep(time.Second)
+	}
+	return nil, nil
+}
+
+func (d *MoPan) Remove(ctx context.Context, obj model.Obj) error {
+	_, err := d.client.DeleteToRecycle([]mopan.TaskFileParam{
+		{
+			FileID:   obj.GetID(),
+			IsFolder: obj.IsDir(),
+			FileName: obj.GetName(),
+		},
+	}, mopan.WarpParamOption(mopan.ParamOptionShareFile(d.CloudID)))
+	return err
+}
+
+func (d *MoPan) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) (model.Obj, error) {
+	file, err := stream.CacheFullInTempFile()
+	if err != nil {
+		return nil, err
+	}
+	defer func() {
+		_ = file.Close()
+	}()
+
+	// step.1
+	uploadPartData, err := mopan.InitUploadPartData(ctx, mopan.UpdloadFileParam{
+		ParentFolderId: dstDir.GetID(),
+		FileName:       stream.GetName(),
+		FileSize:       stream.GetSize(),
+		File:           file,
+	})
+	if err != nil {
+		return nil, err
+	}
+
+	// try to resume previously saved upload progress
+	initUpload, ok := base.GetUploadProgress[*mopan.InitMultiUploadData](d, d.client.Authorization, uploadPartData.FileMd5)
+	if !ok {
+		// step.2
+		initUpload, err = d.client.InitMultiUpload(ctx, *uploadPartData, mopan.WarpParamOption(
+			mopan.ParamOptionShareFile(d.CloudID),
+		))
+		if err != nil {
+			return nil, err
+		}
+	}
+
+	if !initUpload.FileDataExists {
+		// utils.Log.Error(d.client.CloudDiskStartBusiness())
+
+		threadG, upCtx := errgroup.NewGroupWithContext(ctx, d.uploadThread,
+			retry.Attempts(3),
+			retry.Delay(time.Second),
+			retry.DelayType(retry.BackOffDelay))
+
+		// step.3
+		parts, err := d.client.GetAllMultiUploadUrls(initUpload.UploadFileID, initUpload.PartInfos)
+		if err != nil {
+			return nil, err
+		}
+
+		for i, part := range parts {
+			if utils.IsCanceled(upCtx) {
+				break
+			}
+			i, part, byteSize := i, part, initUpload.PartSize
+			if part.PartNumber == uploadPartData.PartTotal {
+				byteSize = initUpload.LastPartSize
+			}
+
+			// step.4
+			threadG.Go(func(ctx context.Context) error {
+				req, err := part.NewRequest(ctx, io.NewSectionReader(file, int64(part.PartNumber-1)*initUpload.PartSize, byteSize))
+				if err != nil {
+					return err
+				}
+				req.ContentLength = byteSize
+				resp, err := base.HttpClient.Do(req)
+				if err != nil {
+					return err
+				}
+				resp.Body.Close()
+				if resp.StatusCode != http.StatusOK {
+					return fmt.Errorf("upload err,code=%d", resp.StatusCode)
+				}
+				up(100 * float64(threadG.Success()) / float64(len(parts)))
+				initUpload.PartInfos[i] = ""
+				return nil
+			})
+		}
+		if err = threadG.Wait(); err != nil {
+			if errors.Is(err, context.Canceled) {
+				initUpload.PartInfos = utils.SliceFilter(initUpload.PartInfos, func(s string) bool { return s != "" })
+				base.SaveUploadProgress(d, initUpload, d.client.Authorization, uploadPartData.FileMd5)
+			}
+			return nil, err
+		}
+	}
+	// step.5
+	uFile, err := d.client.CommitMultiUploadFile(initUpload.UploadFileID, nil)
+	if err != nil {
+		return nil, err
+	}
+	return &model.Object{
+		ID:       uFile.UserFileID,
+		Name:     uFile.FileName,
+		Size:     int64(uFile.FileSize),
+		Modified: time.Time(uFile.CreateDate),
+	}, nil
+}
+
+var _ driver.Driver = (*MoPan)(nil)
+var _ driver.MkdirResult = (*MoPan)(nil)
+var _ driver.MoveResult = (*MoPan)(nil)
+var _ driver.RenameResult = (*MoPan)(nil)
+var _ driver.Remove = (*MoPan)(nil)
+var _ driver.CopyResult = (*MoPan)(nil)
+var _ driver.PutResult = (*MoPan)(nil)
diff --git a/drivers/mopan/meta.go b/drivers/mopan/meta.go
new file mode 100644
index 0000000000000000000000000000000000000000..c111fedc16a9d182e9f0d0c1cd0d9c1d982c8edc
--- /dev/null
+++ b/drivers/mopan/meta.go
@@ -0,0 +1,40 @@
+package mopan
+
+import (
+	"github.com/alist-org/alist/v3/internal/driver"
+	"github.com/alist-org/alist/v3/internal/op"
+)
+
+type Addition struct {
+	Phone    string `json:"phone" required:"true"`
+	Password string `json:"password" required:"true"`
+	SMSCode  string `json:"sms_code" help:"enter 'send' to request an SMS code"`
+
+	RootFolderID
string `json:"root_folder_id" default:""` + + CloudID string `json:"cloud_id"` + + OrderBy string `json:"order_by" type:"select" options:"filename,filesize,lastOpTime" default:"filename"` + OrderDirection string `json:"order_direction" type:"select" options:"asc,desc" default:"asc"` + + DeviceInfo string `json:"device_info"` + + UploadThread string `json:"upload_thread" default:"3" help:"1<=thread<=32"` +} + +func (a *Addition) GetRootId() string { + return a.RootFolderID +} + +var config = driver.Config{ + Name: "MoPan", + // DefaultRoot: "root, / or other", + CheckStatus: true, + Alert: "warning|This network disk may store your password in clear text. Please set your password carefully", +} + +func init() { + op.RegisterDriver(func() driver.Driver { + return &MoPan{} + }) +} diff --git a/drivers/mopan/types.go b/drivers/mopan/types.go new file mode 100644 index 0000000000000000000000000000000000000000..54b02f9a5b2bae6de98f6a5504ba72f1294f9d1e --- /dev/null +++ b/drivers/mopan/types.go @@ -0,0 +1 @@ +package mopan diff --git a/drivers/mopan/util.go b/drivers/mopan/util.go new file mode 100644 index 0000000000000000000000000000000000000000..e6b20f9a22c4bf101d56dcf7b8958038726cdbdd --- /dev/null +++ b/drivers/mopan/util.go @@ -0,0 +1,65 @@ +package mopan + +import ( + "time" + + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/foxxorcat/mopan-sdk-go" +) + +func fileToObj(f mopan.File) model.Obj { + return &model.ObjThumb{ + Object: model.Object{ + ID: string(f.ID), + Name: f.Name, + Size: int64(f.Size), + Modified: time.Time(f.LastOpTime), + Ctime: time.Time(f.CreateDate), + HashInfo: utils.NewHashInfo(utils.MD5, f.Md5), + }, + Thumbnail: model.Thumbnail{ + Thumbnail: f.Icon.SmallURL, + }, + } +} + +func folderToObj(f mopan.Folder) model.Obj { + return &model.Object{ + ID: string(f.ID), + Name: f.Name, + Modified: time.Time(f.LastOpTime), + Ctime: time.Time(f.CreateDate), + IsFolder: true, + } +} + +func CloneObj(o model.Obj, newID, newName string) model.Obj { + if o.IsDir() { + return &model.Object{ + ID: newID, + Name: newName, + IsFolder: true, + Modified: o.ModTime(), + Ctime: o.CreateTime(), + } + } + + thumb := "" + if o, ok := o.(model.Thumb); ok { + thumb = o.Thumb() + } + return &model.ObjThumb{ + Object: model.Object{ + ID: newID, + Name: newName, + Size: o.GetSize(), + Modified: o.ModTime(), + Ctime: o.CreateTime(), + HashInfo: o.GetHash(), + }, + Thumbnail: model.Thumbnail{ + Thumbnail: thumb, + }, + } +} diff --git a/drivers/netease_music/crypto.go b/drivers/netease_music/crypto.go new file mode 100644 index 0000000000000000000000000000000000000000..76ff65486aca8cce8c41afcf9956c4f913fcb94b --- /dev/null +++ b/drivers/netease_music/crypto.go @@ -0,0 +1,135 @@ +package netease_music + +import ( + "bytes" + "crypto/aes" + "crypto/cipher" + "crypto/md5" + "crypto/rsa" + "crypto/x509" + "encoding/base64" + "encoding/hex" + "encoding/pem" + "math/big" + "strings" + + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/alist-org/alist/v3/pkg/utils/random" +) + +var ( + linuxapiKey = []byte("rFgB&h#%2?^eDg:Q") + eapiKey = []byte("e82ckenh8dichen8") + iv = []byte("0102030405060708") + presetKey = []byte("0CoJUm6Qyw8W8jud") + publicKey = []byte("-----BEGIN PUBLIC KEY-----\nMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQDgtQn2JZ34ZC28NWYpAUd98iZ37BUrX/aKzmFbt7clFSs6sXqHauqKWqdtLkF2KexO40H1YTX8z2lSgBBOAxLsvaklV8k4cBFK9snQXE9/DDaFt6Rr7iVZMldczhC0JNgTz+SHXT6CBHuX3e9SdB1Ua44oncaTWz7OBGLbCiK45wIDAQAB\n-----END PUBLIC KEY-----") + stdChars = 
[]byte("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789") +) + +func aesKeyPending(key []byte) []byte { + k := len(key) + count := 0 + switch true { + case k <= 16: + count = 16 - k + case k <= 24: + count = 24 - k + case k <= 32: + count = 32 - k + default: + return key[:32] + } + if count == 0 { + return key + } + + return append(key, bytes.Repeat([]byte{0}, count)...) +} + +func pkcs7Padding(src []byte, blockSize int) []byte { + padding := blockSize - len(src)%blockSize + padtext := bytes.Repeat([]byte{byte(padding)}, padding) + return append(src, padtext...) +} + +func aesCBCEncrypt(src, key, iv []byte) []byte { + block, _ := aes.NewCipher(aesKeyPending(key)) + src = pkcs7Padding(src, block.BlockSize()) + dst := make([]byte, len(src)) + + mode := cipher.NewCBCEncrypter(block, iv) + mode.CryptBlocks(dst, src) + + return dst +} + +func aesECBEncrypt(src, key []byte) []byte { + block, _ := aes.NewCipher(aesKeyPending(key)) + + src = pkcs7Padding(src, block.BlockSize()) + dst := make([]byte, len(src)) + + ecbCryptBlocks(block, dst, src) + + return dst +} + +func ecbCryptBlocks(block cipher.Block, dst, src []byte) { + bs := block.BlockSize() + + for len(src) > 0 { + block.Encrypt(dst, src[:bs]) + src = src[bs:] + dst = dst[bs:] + } +} + +func rsaEncrypt(buffer, key []byte) []byte { + buffers := make([]byte, 128-16, 128) + buffers = append(buffers, buffer...) + block, _ := pem.Decode(key) + pubInterface, _ := x509.ParsePKIXPublicKey(block.Bytes) + pub := pubInterface.(*rsa.PublicKey) + c := new(big.Int).SetBytes([]byte(buffers)) + return c.Exp(c, big.NewInt(int64(pub.E)), pub.N).Bytes() +} + +func getSecretKey() ([]byte, []byte) { + key := make([]byte, 16) + reversed := make([]byte, 16) + for i := 0; i < 16; i++ { + result := stdChars[random.RangeInt64(0, 62)] + key[i] = result + reversed[15-i] = result + } + return key, reversed +} + +func weapi(data map[string]string) map[string]string { + text, _ := utils.Json.Marshal(data) + secretKey, reversedKey := getSecretKey() + params := []byte(base64.StdEncoding.EncodeToString(aesCBCEncrypt(text, presetKey, iv))) + return map[string]string{ + "params": base64.StdEncoding.EncodeToString(aesCBCEncrypt(params, reversedKey, iv)), + "encSecKey": hex.EncodeToString(rsaEncrypt(secretKey, publicKey)), + } +} + +func eapi(url string, data map[string]interface{}) map[string]string { + text, _ := utils.Json.Marshal(data) + msg := "nobody" + url + "use" + string(text) + "md5forencrypt" + h := md5.New() + h.Write([]byte(msg)) + digest := hex.EncodeToString(h.Sum(nil)) + params := []byte(url + "-36cd479b6b5-" + string(text) + "-36cd479b6b5-" + digest) + return map[string]string{ + "params": hex.EncodeToString(aesECBEncrypt(params, eapiKey)), + } +} + +func linuxapi(data map[string]interface{}) map[string]string { + text, _ := utils.Json.Marshal(data) + return map[string]string{ + "eparams": strings.ToUpper(hex.EncodeToString(aesECBEncrypt(text, linuxapiKey))), + } +} diff --git a/drivers/netease_music/driver.go b/drivers/netease_music/driver.go new file mode 100644 index 0000000000000000000000000000000000000000..c0d103de0d972e7d1bc0774a5c9427f4ab3c2b3d --- /dev/null +++ b/drivers/netease_music/driver.go @@ -0,0 +1,110 @@ +package netease_music + +import ( + "context" + "strings" + + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/errs" + "github.com/alist-org/alist/v3/internal/model" + _ "golang.org/x/image/webp" +) + +type NeteaseMusic struct { + model.Storage + Addition + + csrfToken string + musicU 
string
+	fileMapByName map[string]model.Obj
+}
+
+func (d *NeteaseMusic) Config() driver.Config {
+	return config
+}
+
+func (d *NeteaseMusic) GetAddition() driver.Additional {
+	return &d.Addition
+}
+
+func (d *NeteaseMusic) Init(ctx context.Context) error {
+	d.csrfToken = d.Addition.getCookie("__csrf")
+	d.musicU = d.Addition.getCookie("MUSIC_U")
+
+	if d.csrfToken == "" || d.musicU == "" {
+		return errs.EmptyToken
+	}
+
+	return nil
+}
+
+func (d *NeteaseMusic) Drop(ctx context.Context) error {
+	return nil
+}
+
+func (d *NeteaseMusic) Get(ctx context.Context, path string) (model.Obj, error) {
+	if path == "/" {
+		return &model.Object{
+			IsFolder: true,
+			Path:     path,
+		}, nil
+	}
+
+	fragments := strings.Split(path, "/")
+	if len(fragments) > 1 {
+		fileName := fragments[1]
+		if strings.HasSuffix(fileName, ".lrc") {
+			// look the lyric object up safely; the map may not hold it yet
+			if lrc, ok := d.fileMapByName[fileName]; ok {
+				return d.getLyricObj(lrc)
+			}
+			return nil, errs.ObjectNotFound
+		}
+		if song, ok := d.fileMapByName[fileName]; ok {
+			return song, nil
+		} else {
+			return nil, errs.ObjectNotFound
+		}
+	}
+
+	return nil, errs.ObjectNotFound
+}
+
+func (d *NeteaseMusic) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) {
+	return d.getSongObjs(args)
+}
+
+func (d *NeteaseMusic) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) {
+	if lrc, ok := file.(*LyricObj); ok {
+		if args.Type == "parsed" {
+			return lrc.getLyricLink(), nil
+		} else {
+			return lrc.getProxyLink(args), nil
+		}
+	}
+
+	return d.getSongLink(file)
+}
+
+func (d *NeteaseMusic) Remove(ctx context.Context, obj model.Obj) error {
+	return d.removeSongObj(obj)
+}
+
+func (d *NeteaseMusic) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error {
+	return d.putSongStream(stream)
+}
+
+func (d *NeteaseMusic) Copy(ctx context.Context, srcObj, dstDir model.Obj) error {
+	return errs.NotSupport
+}
+
+func (d *NeteaseMusic) Move(ctx context.Context, srcObj, dstDir model.Obj) error {
+	return errs.NotSupport
+}
+
+func (d *NeteaseMusic) MakeDir(ctx context.Context, parentDir model.Obj, dirName string) error {
+	return errs.NotSupport
+}
+
+func (d *NeteaseMusic) Rename(ctx context.Context, srcObj model.Obj, newName string) error {
+	return errs.NotSupport
+}
+
+var _ driver.Driver = (*NeteaseMusic)(nil)
diff --git a/drivers/netease_music/meta.go b/drivers/netease_music/meta.go
new file mode 100644
index 0000000000000000000000000000000000000000..8ddfd7281786b09fbdcbd5cd7976e9cd84da0171
--- /dev/null
+++ b/drivers/netease_music/meta.go
@@ -0,0 +1,32 @@
+package netease_music
+
+import (
+	"regexp"
+
+	"github.com/alist-org/alist/v3/internal/driver"
+	"github.com/alist-org/alist/v3/internal/op"
+)
+
+type Addition struct {
+	Cookie    string `json:"cookie" type:"text" required:"true" help:""`
+	SongLimit uint64 `json:"song_limit" default:"200" type:"number" help:"fetch at most this many songs (200 by default)"`
+}
+
+func (ad *Addition) getCookie(name string) string {
+	// capture everything after "name=" up to the next ';'
+	re := regexp.MustCompile(name + "=([^;]+)")
+	matches := re.FindStringSubmatch(ad.Cookie)
+	if len(matches) < 2 {
+		return ""
+	}
+	return matches[1]
+}
+
+var config = driver.Config{
+	Name: "NeteaseMusic",
+}
+
+func init() {
+	op.RegisterDriver(func() driver.Driver {
+		return &NeteaseMusic{}
+	})
+}
diff --git a/drivers/netease_music/types.go b/drivers/netease_music/types.go
new file mode 100644
index 0000000000000000000000000000000000000000..edbd40eed594b57e4ab2532900f2c5d6368da168
--- /dev/null
+++ b/drivers/netease_music/types.go
@@ -0,0 +1,116 @@
+package netease_music
+
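+// types.go gathers the JSON payloads, the lyric pseudo-object, and the
+// request helpers shared by the NetEase Music driver.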
+import ( + "context" + "io" + "net/http" + "strconv" + "strings" + "time" + + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/internal/sign" + "github.com/alist-org/alist/v3/pkg/http_range" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/alist-org/alist/v3/pkg/utils/random" + "github.com/alist-org/alist/v3/server/common" +) + +type HostsResp struct { + Upload []string `json:"upload"` +} + +type SongResp struct { + Data []struct { + Url string `json:"url"` + } `json:"data"` +} + +type ListResp struct { + Size string `json:"size"` + MaxSize string `json:"maxSize"` + Data []struct { + AddTime int64 `json:"addTime"` + FileName string `json:"fileName"` + FileSize int64 `json:"fileSize"` + SongId int64 `json:"songId"` + SimpleSong struct { + Al struct { + PicUrl string `json:"picUrl"` + } `json:"al"` + } `json:"simpleSong"` + } `json:"data"` +} + +type LyricObj struct { + model.Object + lyric string +} + +func (lrc *LyricObj) getProxyLink(args model.LinkArgs) *model.Link { + rawURL := common.GetApiUrl(args.HttpReq) + "/p" + lrc.Path + rawURL = utils.EncodePath(rawURL, true) + "?type=parsed&sign=" + sign.Sign(lrc.Path) + return &model.Link{URL: rawURL} +} + +func (lrc *LyricObj) getLyricLink() *model.Link { + reader := strings.NewReader(lrc.lyric) + return &model.Link{ + RangeReadCloser: &model.RangeReadCloser{ + RangeReader: func(ctx context.Context, httpRange http_range.Range) (io.ReadCloser, error) { + if httpRange.Length < 0 { + return io.NopCloser(reader), nil + } + sr := io.NewSectionReader(reader, httpRange.Start, httpRange.Length) + return io.NopCloser(sr), nil + }, + Closers: utils.EmptyClosers(), + }, + } +} + +type ReqOption struct { + crypto string + stream model.FileStreamer + data map[string]string + headers map[string]string + cookies []*http.Cookie + url string +} + +type Characteristic map[string]string + +func (ch *Characteristic) fromDriver(d *NeteaseMusic) *Characteristic { + *ch = map[string]string{ + "osver": "", + "deviceId": "", + "mobilename": "", + "appver": "6.1.1", + "versioncode": "140", + "buildver": strconv.FormatInt(time.Now().Unix(), 10), + "resolution": "1920x1080", + "os": "android", + "channel": "", + "requestId": strconv.FormatInt(time.Now().Unix()*1000, 10) + strconv.Itoa(int(random.RangeInt64(0, 1000))), + "MUSIC_U": d.musicU, + } + return ch +} + +func (ch Characteristic) toCookies() []*http.Cookie { + cookies := make([]*http.Cookie, 0) + for k, v := range ch { + cookies = append(cookies, &http.Cookie{Name: k, Value: v}) + } + return cookies +} + +func (ch *Characteristic) merge(data map[string]string) map[string]interface{} { + body := map[string]interface{}{ + "header": ch, + } + for k, v := range data { + body[k] = v + } + return body +} diff --git a/drivers/netease_music/upload.go b/drivers/netease_music/upload.go new file mode 100644 index 0000000000000000000000000000000000000000..ece496b36da00391637588a0947dad9470499563 --- /dev/null +++ b/drivers/netease_music/upload.go @@ -0,0 +1,208 @@ +package netease_music + +import ( + "crypto/md5" + "encoding/hex" + "io" + "net/http" + "strconv" + "strings" + + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/dhowden/tag" +) + +type token struct { + resourceId string + objectKey string + token string +} + +type songmeta struct { + needUpload bool + songId string + name string + artist string + album string +} + +type uploader struct { + driver *NeteaseMusic + file model.File + meta songmeta + md5 string + ext 
string
+	size     string
+	filename string
+}
+
+func (u *uploader) init(stream model.FileStreamer) error {
+	u.filename = stream.GetName()
+	u.size = strconv.FormatInt(stream.GetSize(), 10)
+
+	// default to mp3; only a flac mimetype switches the extension
+	u.ext = "mp3"
+	if strings.HasSuffix(stream.GetMimetype(), "flac") {
+		u.ext = "flac"
+	}
+
+	h := md5.New()
+	if _, err := io.Copy(h, stream); err != nil {
+		return err
+	}
+	u.md5 = hex.EncodeToString(h.Sum(nil))
+	_, err := u.file.Seek(0, io.SeekStart)
+	if err != nil {
+		return err
+	}
+
+	if m, err := tag.ReadFrom(u.file); err != nil {
+		u.meta = songmeta{}
+	} else {
+		u.meta = songmeta{
+			name:   m.Title(),
+			artist: m.Artist(),
+			album:  m.Album(),
+		}
+	}
+	if u.meta.name == "" {
+		u.meta.name = u.filename
+	}
+	if u.meta.album == "" {
+		u.meta.album = "未知专辑" // "Unknown Album"
+	}
+	if u.meta.artist == "" {
+		u.meta.artist = "未知艺术家" // "Unknown Artist"
+	}
+	_, err = u.file.Seek(0, io.SeekStart)
+	if err != nil {
+		return err
+	}
+
+	return nil
+}
+
+func (u *uploader) checkIfExisted() error {
+	body, err := u.driver.request("https://interface.music.163.com/api/cloud/upload/check", http.MethodPost,
+		ReqOption{
+			crypto: "weapi",
+			data: map[string]string{
+				"ext":     "",
+				"songId":  "0",
+				"version": "1",
+				"bitrate": "999000",
+				"length":  u.size,
+				"md5":     u.md5,
+			},
+			cookies: []*http.Cookie{
+				{Name: "os", Value: "pc"},
+				{Name: "appver", Value: "2.9.7"},
+			},
+		},
+	)
+	if err != nil {
+		return err
+	}
+
+	u.meta.songId = utils.Json.Get(body, "songId").ToString()
+	u.meta.needUpload = utils.Json.Get(body, "needUpload").ToBool()
+
+	return nil
+}
+
+func (u *uploader) allocToken(bucket ...string) (token, error) {
+	if len(bucket) == 0 {
+		bucket = []string{""}
+	}
+
+	body, err := u.driver.request("https://music.163.com/weapi/nos/token/alloc", http.MethodPost, ReqOption{
+		crypto: "weapi",
+		data: map[string]string{
+			"bucket":      bucket[0],
+			"local":       "false",
+			"type":        "audio",
+			"nos_product": "3",
+			"filename":    u.filename,
+			"md5":         u.md5,
+			"ext":         u.ext,
+		},
+	})
+	if err != nil {
+		return token{}, err
+	}
+
+	return token{
+		resourceId: utils.Json.Get(body, "result", "resourceId").ToString(),
+		objectKey:  utils.Json.Get(body, "result", "objectKey").ToString(),
+		token:      utils.Json.Get(body, "result", "token").ToString(),
+	}, nil
+}
+
+func (u *uploader) publishInfo(resourceId string) error {
+	body, err := u.driver.request("https://music.163.com/api/upload/cloud/info/v2", http.MethodPost, ReqOption{
+		crypto: "weapi",
+		data: map[string]string{
+			"md5":        u.md5,
+			"filename":   u.filename,
+			"song":       u.meta.name,
+			"album":      u.meta.album,
+			"artist":     u.meta.artist,
+			"songid":     u.meta.songId,
+			"resourceId": resourceId,
+			"bitrate":    "999000",
+		},
+	})
+	if err != nil {
+		return err
+	}
+
+	_, err = u.driver.request("https://interface.music.163.com/api/cloud/pub/v2", http.MethodPost, ReqOption{
+		crypto: "weapi",
+		data: map[string]string{
+			"songid": utils.Json.Get(body, "songId").ToString(),
+		},
+	})
+	if err != nil {
+		return err
+	}
+
+	return nil
+}
+
+func (u *uploader) upload(stream model.FileStreamer) error {
+	bucket := "jd-musicrep-privatecloud-audio-public"
+	token, err := u.allocToken(bucket)
+	if err != nil {
+		return err
+	}
+
+	body, err := u.driver.request("https://wanproxy.127.net/lbs?version=1.0&bucketname="+bucket, http.MethodGet,
+		ReqOption{},
+	)
+	if err != nil {
+		return err
+	}
+	var resp HostsResp
+	err = utils.Json.Unmarshal(body, &resp)
+	if err != nil {
+		return err
+	}
+
+	objectKey := strings.ReplaceAll(token.objectKey, "/", "%2F")
+	_, err = u.driver.request(
+		resp.Upload[0]+"/"+bucket+"/"+objectKey+"?offset=0&complete=true&version=1.0",
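+		// Assumption: offset=0&complete=true appears to ask the NOS upload
+		// host to accept the whole object in a single request (no chunking).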
http.MethodPost, + ReqOption{ + stream: stream, + headers: map[string]string{ + "x-nos-token": token.token, + "Content-Type": "audio/mpeg", + "Content-Length": u.size, + "Content-MD5": u.md5, + }, + }, + ) + if err != nil { + return err + } + + return nil +} diff --git a/drivers/netease_music/util.go b/drivers/netease_music/util.go new file mode 100644 index 0000000000000000000000000000000000000000..4d0696eb82be3e0efab7c0d9b606b229580628a9 --- /dev/null +++ b/drivers/netease_music/util.go @@ -0,0 +1,246 @@ +package netease_music + +import ( + "io" + "net/http" + "path" + "regexp" + "strconv" + "strings" + "time" + + "github.com/alist-org/alist/v3/drivers/base" + "github.com/alist-org/alist/v3/internal/errs" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/pkg/utils" +) + +func (d *NeteaseMusic) request(url, method string, opt ReqOption) ([]byte, error) { + req := base.RestyClient.R() + + req.SetHeader("Cookie", d.Addition.Cookie) + + if strings.Contains(url, "music.163.com") { + req.SetHeader("Referer", "https://music.163.com") + } + + if opt.cookies != nil { + for _, cookie := range opt.cookies { + req.SetCookie(cookie) + } + } + + if opt.headers != nil { + for header, value := range opt.headers { + req.SetHeader(header, value) + } + } + + data := opt.data + if opt.crypto == "weapi" { + data = weapi(data) + re, _ := regexp.Compile(`/\w*api/`) + url = re.ReplaceAllString(url, "/weapi/") + } else if opt.crypto == "eapi" { + ch := new(Characteristic).fromDriver(d) + req.SetCookies(ch.toCookies()) + data = eapi(opt.url, ch.merge(data)) + re, _ := regexp.Compile(`/\w*api/`) + url = re.ReplaceAllString(url, "/eapi/") + } else if opt.crypto == "linuxapi" { + re, _ := regexp.Compile(`/\w*api/`) + data = linuxapi(map[string]interface{}{ + "url": re.ReplaceAllString(url, "/api/"), + "method": method, + "params": data, + }) + req.Header.Set("User-Agent", "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/60.0.3112.90 Safari/537.36") + url = "https://music.163.com/api/linux/forward" + } + + if method == http.MethodPost { + if opt.stream != nil { + req.SetContentLength(true) + req.SetBody(io.ReadCloser(opt.stream)) + } else { + req.SetFormData(data) + } + res, err := req.Post(url) + return res.Body(), err + } + + if method == http.MethodGet { + res, err := req.Get(url) + return res.Body(), err + } + + return nil, errs.NotImplement +} + +func (d *NeteaseMusic) getSongObjs(args model.ListArgs) ([]model.Obj, error) { + body, err := d.request("https://music.163.com/weapi/v1/cloud/get", http.MethodPost, ReqOption{ + crypto: "weapi", + data: map[string]string{ + "limit": strconv.FormatUint(d.Addition.SongLimit, 10), + "offset": "0", + }, + cookies: []*http.Cookie{ + {Name: "os", Value: "pc"}, + }, + }) + if err != nil { + return nil, err + } + + var resp ListResp + err = utils.Json.Unmarshal(body, &resp) + if err != nil { + return nil, err + } + + d.fileMapByName = make(map[string]model.Obj) + files := make([]model.Obj, 0, len(resp.Data)) + for _, f := range resp.Data { + song := &model.ObjThumb{ + Object: model.Object{ + IsFolder: false, + Size: f.FileSize, + Name: f.FileName, + Modified: time.UnixMilli(f.AddTime), + ID: strconv.FormatInt(f.SongId, 10), + }, + Thumbnail: model.Thumbnail{Thumbnail: f.SimpleSong.Al.PicUrl}, + } + d.fileMapByName[song.Name] = song + files = append(files, song) + + // map song id for lyric + lrcName := strings.Split(f.FileName, ".")[0] + ".lrc" + lrc := &model.Object{ + IsFolder: false, + Name: lrcName, + Path: 
path.Join(args.ReqPath, lrcName), + ID: strconv.FormatInt(f.SongId, 10), + } + d.fileMapByName[lrc.Name] = lrc + } + + return files, nil +} + +func (d *NeteaseMusic) getSongLink(file model.Obj) (*model.Link, error) { + body, err := d.request( + "https://music.163.com/api/song/enhance/player/url", http.MethodPost, ReqOption{ + crypto: "linuxapi", + data: map[string]string{ + "ids": "[" + file.GetID() + "]", + "br": "999000", + }, + cookies: []*http.Cookie{ + {Name: "os", Value: "pc"}, + }, + }, + ) + if err != nil { + return nil, err + } + + var resp SongResp + err = utils.Json.Unmarshal(body, &resp) + if err != nil { + return nil, err + } + + if len(resp.Data) < 1 { + return nil, errs.ObjectNotFound + } + + return &model.Link{URL: resp.Data[0].Url}, nil +} + +func (d *NeteaseMusic) getLyricObj(file model.Obj) (model.Obj, error) { + if lrc, ok := file.(*LyricObj); ok { + return lrc, nil + } + + body, err := d.request( + "https://music.163.com/api/song/lyric?_nmclfl=1", http.MethodPost, ReqOption{ + data: map[string]string{ + "id": file.GetID(), + "tv": "-1", + "lv": "-1", + "rv": "-1", + "kv": "-1", + }, + cookies: []*http.Cookie{ + {Name: "os", Value: "ios"}, + }, + }, + ) + if err != nil { + return nil, err + } + + lyric := utils.Json.Get(body, "lrc", "lyric").ToString() + + return &LyricObj{ + lyric: lyric, + Object: model.Object{ + IsFolder: false, + ID: file.GetID(), + Name: file.GetName(), + Path: file.GetPath(), + Size: int64(len(lyric)), + }, + }, nil +} + +func (d *NeteaseMusic) removeSongObj(file model.Obj) error { + _, err := d.request("http://music.163.com/weapi/cloud/del", http.MethodPost, ReqOption{ + crypto: "weapi", + data: map[string]string{ + "songIds": "[" + file.GetID() + "]", + }, + }) + + return err +} + +func (d *NeteaseMusic) putSongStream(stream model.FileStreamer) error { + tmp, err := stream.CacheFullInTempFile() + if err != nil { + return err + } + defer tmp.Close() + + u := uploader{driver: d, file: tmp} + + err = u.init(stream) + if err != nil { + return err + } + + err = u.checkIfExisted() + if err != nil { + return err + } + + token, err := u.allocToken() + if err != nil { + return err + } + + if u.meta.needUpload { + err = u.upload(stream) + if err != nil { + return err + } + } + + err = u.publishInfo(token.resourceId) + if err != nil { + return err + } + + return nil +} diff --git a/drivers/onedrive/driver.go b/drivers/onedrive/driver.go new file mode 100644 index 0000000000000000000000000000000000000000..f134788768c30b939921609a32a4eb589f5f0303 --- /dev/null +++ b/drivers/onedrive/driver.go @@ -0,0 +1,221 @@ +package onedrive + +import ( + "context" + "fmt" + "net/http" + "net/url" + "path" + "strings" + "sync" + + "github.com/alist-org/alist/v3/drivers/base" + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/errs" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/go-resty/resty/v2" +) + +type Onedrive struct { + model.Storage + Addition + AccessToken string + root *Object + mutex sync.Mutex +} + +func (d *Onedrive) Config() driver.Config { + return config +} + +func (d *Onedrive) GetAddition() driver.Additional { + return &d.Addition +} + +func (d *Onedrive) Init(ctx context.Context) error { + if d.ChunkSize < 1 { + d.ChunkSize = 5 + } + return d.refreshToken() +} + +func (d *Onedrive) Drop(ctx context.Context) error { + return nil +} + +func (d *Onedrive) GetRoot(ctx context.Context) (model.Obj, error) { + if d.root != nil { + return d.root, nil + } + 
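+	// The nil check above runs outside the mutex, so two concurrent first
+	// calls may both rebuild the root object; the lock only serializes them.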
d.mutex.Lock() + defer d.mutex.Unlock() + root := &Object{ + ObjThumb: model.ObjThumb{ + Object: model.Object{ + ID: "root", + Path: d.RootFolderPath, + Name: "root", + Size: 0, + Modified: d.Modified, + Ctime: d.Modified, + IsFolder: true, + }, + }, + ParentID: "", + } + if !utils.PathEqual(d.RootFolderPath, "/") { + // get root folder id + url := d.GetMetaUrl(false, d.RootFolderPath) + var resp struct { + Id string `json:"id"` + } + _, err := d.Request(url, http.MethodGet, nil, &resp) + if err != nil { + return nil, err + } + root.ID = resp.Id + } + d.root = root + return d.root, nil +} + +func (d *Onedrive) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) { + files, err := d.getFiles(dir.GetPath()) + if err != nil { + return nil, err + } + return utils.SliceConvert(files, func(src File) (model.Obj, error) { + return fileToObj(src, dir.GetID()), nil + }) +} + +func (d *Onedrive) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) { + f, err := d.GetFile(file.GetPath()) + if err != nil { + return nil, err + } + if f.File == nil { + return nil, errs.NotFile + } + u := f.Url + if d.CustomHost != "" { + _u, err := url.Parse(f.Url) + if err != nil { + return nil, err + } + _u.Host = d.CustomHost + u = _u.String() + } + + if d.ProxyUrl != "" { + + if strings.HasSuffix(d.ProxyUrl, "/") { + u = d.ProxyUrl + f.Url + } else { + u = d.ProxyUrl + "/" + f.Url + } + + } + + return &model.Link{ + URL: u, + }, nil +} + +func (d *Onedrive) MakeDir(ctx context.Context, parentDir model.Obj, dirName string) error { + url := d.GetMetaUrl(false, parentDir.GetPath()) + "/children" + data := base.Json{ + "name": dirName, + "folder": base.Json{}, + "@microsoft.graph.conflictBehavior": "rename", + } + _, err := d.Request(url, http.MethodPost, func(req *resty.Request) { + req.SetBody(data) + }, nil) + return err +} + +func (d *Onedrive) Move(ctx context.Context, srcObj, dstDir model.Obj) error { + parentPath := "" + if dstDir.GetID() == "" { + parentPath = dstDir.GetPath() + if utils.PathEqual(parentPath, "/") { + parentPath = path.Join("/drive/root", parentPath) + } else { + parentPath = path.Join("/drive/root:/", parentPath) + } + } + data := base.Json{ + "parentReference": base.Json{ + "id": dstDir.GetID(), + "path": parentPath, + }, + "name": srcObj.GetName(), + } + url := d.GetMetaUrl(false, srcObj.GetPath()) + _, err := d.Request(url, http.MethodPatch, func(req *resty.Request) { + req.SetBody(data) + }, nil) + return err +} + +func (d *Onedrive) Rename(ctx context.Context, srcObj model.Obj, newName string) error { + var parentID string + if o, ok := srcObj.(*Object); ok { + parentID = o.ParentID + } else { + return fmt.Errorf("srcObj is not Object") + } + if parentID == "" { + parentID = "root" + } + data := base.Json{ + "parentReference": base.Json{ + "id": parentID, + }, + "name": newName, + } + url := d.GetMetaUrl(false, srcObj.GetPath()) + _, err := d.Request(url, http.MethodPatch, func(req *resty.Request) { + req.SetBody(data) + }, nil) + return err +} + +func (d *Onedrive) Copy(ctx context.Context, srcObj, dstDir model.Obj) error { + dst, err := d.GetFile(dstDir.GetPath()) + if err != nil { + return err + } + data := base.Json{ + "parentReference": base.Json{ + "driveId": dst.ParentReference.DriveId, + "id": dst.Id, + }, + "name": srcObj.GetName(), + } + url := d.GetMetaUrl(false, srcObj.GetPath()) + "/copy" + _, err = d.Request(url, http.MethodPost, func(req *resty.Request) { + req.SetBody(data) + }, nil) + return err +} + +func (d 
*Onedrive) Remove(ctx context.Context, obj model.Obj) error { + url := d.GetMetaUrl(false, obj.GetPath()) + _, err := d.Request(url, http.MethodDelete, nil, nil) + return err +} + +func (d *Onedrive) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error { + var err error + if stream.GetSize() <= 4*1024*1024 { + err = d.upSmall(ctx, dstDir, stream) + } else { + err = d.upBig(ctx, dstDir, stream, up) + } + return err +} + +var _ driver.Driver = (*Onedrive)(nil) diff --git a/drivers/onedrive/meta.go b/drivers/onedrive/meta.go new file mode 100644 index 0000000000000000000000000000000000000000..a750522f6e19e20319f8d995985465644d35fe61 --- /dev/null +++ b/drivers/onedrive/meta.go @@ -0,0 +1,32 @@ +package onedrive + +import ( + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/op" +) + +type Addition struct { + driver.RootPath + Region string `json:"region" type:"select" required:"true" options:"global,cn,us,de" default:"global"` + IsSharepoint bool `json:"is_sharepoint"` + ClientID string `json:"client_id" required:"true"` + ClientSecret string `json:"client_secret" required:"true"` + RedirectUri string `json:"redirect_uri" required:"true" default:"https://alist.nn.ci/tool/onedrive/callback"` + RefreshToken string `json:"refresh_token" required:"true"` + SiteId string `json:"site_id"` + ChunkSize int64 `json:"chunk_size" type:"number" default:"5"` + CustomHost string `json:"custom_host" help:"Custom host for onedrive download link"` + ProxyUrl string `json:"proxy_url" help:"ProxyUrl for onedrive download link like pikpak"` +} + +var config = driver.Config{ + Name: "Onedrive", + LocalSort: true, + DefaultRoot: "/", +} + +func init() { + op.RegisterDriver(func() driver.Driver { + return &Onedrive{} + }) +} diff --git a/drivers/onedrive/types.go b/drivers/onedrive/types.go new file mode 100644 index 0000000000000000000000000000000000000000..69264abcf03d16612f94a0944e66a2efe6ceb89b --- /dev/null +++ b/drivers/onedrive/types.go @@ -0,0 +1,74 @@ +package onedrive + +import ( + "time" + + "github.com/alist-org/alist/v3/internal/model" +) + +type Host struct { + Oauth string + Api string +} + +type TokenErr struct { + Error string `json:"error"` + ErrorDescription string `json:"error_description"` +} + +type RespErr struct { + Error struct { + Code string `json:"code"` + Message string `json:"message"` + } `json:"error"` +} + +type File struct { + Id string `json:"id"` + Name string `json:"name"` + Size int64 `json:"size"` + LastModifiedDateTime time.Time `json:"lastModifiedDateTime"` + Url string `json:"@microsoft.graph.downloadUrl"` + File *struct { + MimeType string `json:"mimeType"` + } `json:"file"` + Thumbnails []struct { + Medium struct { + Url string `json:"url"` + } `json:"medium"` + } `json:"thumbnails"` + ParentReference struct { + DriveId string `json:"driveId"` + } `json:"parentReference"` +} + +type Object struct { + model.ObjThumb + ParentID string +} + +func fileToObj(f File, parentID string) *Object { + thumb := "" + if len(f.Thumbnails) > 0 { + thumb = f.Thumbnails[0].Medium.Url + } + return &Object{ + ObjThumb: model.ObjThumb{ + Object: model.Object{ + ID: f.Id, + Name: f.Name, + Size: f.Size, + Modified: f.LastModifiedDateTime, + IsFolder: f.File == nil, + }, + Thumbnail: model.Thumbnail{Thumbnail: thumb}, + //Url: model.Url{Url: f.Url}, + }, + ParentID: parentID, + } +} + +type Files struct { + Value []File `json:"value"` + NextLink string `json:"@odata.nextLink"` +} diff --git 
a/drivers/onedrive/util.go b/drivers/onedrive/util.go new file mode 100644 index 0000000000000000000000000000000000000000..a0c6fa8fcbfdbfc985291732a8da25a3dc0549da --- /dev/null +++ b/drivers/onedrive/util.go @@ -0,0 +1,209 @@ +package onedrive + +import ( + "bytes" + "context" + "errors" + "fmt" + "io" + "net/http" + stdpath "path" + "strconv" + + "github.com/alist-org/alist/v3/drivers/base" + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/errs" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/internal/op" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/go-resty/resty/v2" + jsoniter "github.com/json-iterator/go" + log "github.com/sirupsen/logrus" +) + +var onedriveHostMap = map[string]Host{ + "global": { + Oauth: "https://login.microsoftonline.com", + Api: "https://graph.microsoft.com", + }, + "cn": { + Oauth: "https://login.chinacloudapi.cn", + Api: "https://microsoftgraph.chinacloudapi.cn", + }, + "us": { + Oauth: "https://login.microsoftonline.us", + Api: "https://graph.microsoft.us", + }, + "de": { + Oauth: "https://login.microsoftonline.de", + Api: "https://graph.microsoft.de", + }, +} + +func (d *Onedrive) GetMetaUrl(auth bool, path string) string { + host, _ := onedriveHostMap[d.Region] + path = utils.EncodePath(path, true) + if auth { + return host.Oauth + } + if d.IsSharepoint { + if path == "/" || path == "\\" { + return fmt.Sprintf("%s/v1.0/sites/%s/drive/root", host.Api, d.SiteId) + } else { + return fmt.Sprintf("%s/v1.0/sites/%s/drive/root:%s:", host.Api, d.SiteId, path) + } + } else { + if path == "/" || path == "\\" { + return fmt.Sprintf("%s/v1.0/me/drive/root", host.Api) + } else { + return fmt.Sprintf("%s/v1.0/me/drive/root:%s:", host.Api, path) + } + } +} + +func (d *Onedrive) refreshToken() error { + var err error + for i := 0; i < 3; i++ { + err = d._refreshToken() + if err == nil { + break + } + } + return err +} + +func (d *Onedrive) _refreshToken() error { + url := d.GetMetaUrl(true, "") + "/common/oauth2/v2.0/token" + var resp base.TokenResp + var e TokenErr + _, err := base.RestyClient.R().SetResult(&resp).SetError(&e).SetFormData(map[string]string{ + "grant_type": "refresh_token", + "client_id": d.ClientID, + "client_secret": d.ClientSecret, + "redirect_uri": d.RedirectUri, + "refresh_token": d.RefreshToken, + }).Post(url) + if err != nil { + return err + } + if e.Error != "" { + return fmt.Errorf("%s", e.ErrorDescription) + } + if resp.RefreshToken == "" { + return errs.EmptyToken + } + d.RefreshToken, d.AccessToken = resp.RefreshToken, resp.AccessToken + op.MustSaveDriverStorage(d) + return nil +} + +func (d *Onedrive) Request(url string, method string, callback base.ReqCallback, resp interface{}) ([]byte, error) { + req := base.RestyClient.R() + req.SetHeader("Authorization", "Bearer "+d.AccessToken) + if callback != nil { + callback(req) + } + if resp != nil { + req.SetResult(resp) + } + var e RespErr + req.SetError(&e) + res, err := req.Execute(method, url) + if err != nil { + return nil, err + } + if e.Error.Code != "" { + if e.Error.Code == "InvalidAuthenticationToken" { + err = d.refreshToken() + if err != nil { + return nil, err + } + return d.Request(url, method, callback, resp) + } + return nil, errors.New(e.Error.Message) + } + return res.Body(), nil +} + +func (d *Onedrive) getFiles(path string) ([]File, error) { + var res []File + nextLink := d.GetMetaUrl(false, path) + 
"/children?$top=5000&$expand=thumbnails($select=medium)&$select=id,name,size,lastModifiedDateTime,content.downloadUrl,file,parentReference" + for nextLink != "" { + var files Files + _, err := d.Request(nextLink, http.MethodGet, nil, &files) + if err != nil { + return nil, err + } + res = append(res, files.Value...) + nextLink = files.NextLink + } + return res, nil +} + +func (d *Onedrive) GetFile(path string) (*File, error) { + var file File + u := d.GetMetaUrl(false, path) + _, err := d.Request(u, http.MethodGet, nil, &file) + return &file, err +} + +func (d *Onedrive) upSmall(ctx context.Context, dstDir model.Obj, stream model.FileStreamer) error { + url := d.GetMetaUrl(false, stdpath.Join(dstDir.GetPath(), stream.GetName())) + "/content" + data, err := io.ReadAll(stream) + if err != nil { + return err + } + _, err = d.Request(url, http.MethodPut, func(req *resty.Request) { + req.SetBody(data).SetContext(ctx) + }, nil) + return err +} + +func (d *Onedrive) upBig(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error { + url := d.GetMetaUrl(false, stdpath.Join(dstDir.GetPath(), stream.GetName())) + "/createUploadSession" + res, err := d.Request(url, http.MethodPost, nil, nil) + if err != nil { + return err + } + uploadUrl := jsoniter.Get(res, "uploadUrl").ToString() + var finish int64 = 0 + DEFAULT := d.ChunkSize * 1024 * 1024 + for finish < stream.GetSize() { + if utils.IsCanceled(ctx) { + return ctx.Err() + } + log.Debugf("upload: %d", finish) + var byteSize int64 = DEFAULT + left := stream.GetSize() - finish + if left < DEFAULT { + byteSize = left + } + byteData := make([]byte, byteSize) + n, err := io.ReadFull(stream, byteData) + log.Debug(err, n) + if err != nil { + return err + } + req, err := http.NewRequest("PUT", uploadUrl, bytes.NewBuffer(byteData)) + if err != nil { + return err + } + req = req.WithContext(ctx) + req.Header.Set("Content-Length", strconv.Itoa(int(byteSize))) + req.Header.Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", finish, finish+byteSize-1, stream.GetSize())) + finish += byteSize + res, err := base.HttpClient.Do(req) + if err != nil { + return err + } + // https://learn.microsoft.com/zh-cn/onedrive/developer/rest-api/api/driveitem_createuploadsession + if res.StatusCode != 201 && res.StatusCode != 202 && res.StatusCode != 200 { + data, _ := io.ReadAll(res.Body) + res.Body.Close() + return errors.New(string(data)) + } + res.Body.Close() + up(float64(finish) * 100 / float64(stream.GetSize())) + } + return nil +} diff --git a/drivers/onedrive_app/driver.go b/drivers/onedrive_app/driver.go new file mode 100644 index 0000000000000000000000000000000000000000..8a924341bb7bc6eecf6ddd135914dd0124567118 --- /dev/null +++ b/drivers/onedrive_app/driver.go @@ -0,0 +1,209 @@ +package onedrive_app + +import ( + "context" + "fmt" + "net/http" + "net/url" + "path" + "sync" + + "github.com/alist-org/alist/v3/drivers/base" + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/errs" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/go-resty/resty/v2" +) + +type OnedriveAPP struct { + model.Storage + Addition + AccessToken string + root *Object + mutex sync.Mutex +} + +func (d *OnedriveAPP) Config() driver.Config { + return config +} + +func (d *OnedriveAPP) GetAddition() driver.Additional { + return &d.Addition +} + +func (d *OnedriveAPP) Init(ctx context.Context) error { + if d.ChunkSize < 1 { + d.ChunkSize = 5 + } + return 
d.accessToken() +} + +func (d *OnedriveAPP) Drop(ctx context.Context) error { + return nil +} + +func (d *OnedriveAPP) GetRoot(ctx context.Context) (model.Obj, error) { + if d.root != nil { + return d.root, nil + } + d.mutex.Lock() + defer d.mutex.Unlock() + root := &Object{ + ObjThumb: model.ObjThumb{ + Object: model.Object{ + ID: "root", + Path: d.RootFolderPath, + Name: "root", + Size: 0, + Modified: d.Modified, + Ctime: d.Modified, + IsFolder: true, + }, + }, + ParentID: "", + } + if !utils.PathEqual(d.RootFolderPath, "/") { + // get root folder id + url := d.GetMetaUrl(false, d.RootFolderPath) + var resp struct { + Id string `json:"id"` + } + _, err := d.Request(url, http.MethodGet, nil, &resp) + if err != nil { + return nil, err + } + root.ID = resp.Id + } + d.root = root + return d.root, nil +} + +func (d *OnedriveAPP) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) { + files, err := d.getFiles(dir.GetPath()) + if err != nil { + return nil, err + } + return utils.SliceConvert(files, func(src File) (model.Obj, error) { + return fileToObj(src, dir.GetID()), nil + }) +} + +func (d *OnedriveAPP) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) { + f, err := d.GetFile(file.GetPath()) + if err != nil { + return nil, err + } + if f.File == nil { + return nil, errs.NotFile + } + u := f.Url + if d.CustomHost != "" { + _u, err := url.Parse(f.Url) + if err != nil { + return nil, err + } + _u.Host = d.CustomHost + u = _u.String() + } + return &model.Link{ + URL: u, + }, nil +} + +func (d *OnedriveAPP) MakeDir(ctx context.Context, parentDir model.Obj, dirName string) error { + url := d.GetMetaUrl(false, parentDir.GetPath()) + "/children" + data := base.Json{ + "name": dirName, + "folder": base.Json{}, + "@microsoft.graph.conflictBehavior": "rename", + } + _, err := d.Request(url, http.MethodPost, func(req *resty.Request) { + req.SetBody(data) + }, nil) + return err +} + +func (d *OnedriveAPP) Move(ctx context.Context, srcObj, dstDir model.Obj) error { + parentPath := "" + if dstDir.GetID() == "" { + parentPath = dstDir.GetPath() + if utils.PathEqual(parentPath, "/") { + parentPath = path.Join("/drive/root", parentPath) + } else { + parentPath = path.Join("/drive/root:/", parentPath) + } + } + data := base.Json{ + "parentReference": base.Json{ + "id": dstDir.GetID(), + "path": parentPath, + }, + "name": srcObj.GetName(), + } + url := d.GetMetaUrl(false, srcObj.GetPath()) + _, err := d.Request(url, http.MethodPatch, func(req *resty.Request) { + req.SetBody(data) + }, nil) + return err +} + +func (d *OnedriveAPP) Rename(ctx context.Context, srcObj model.Obj, newName string) error { + var parentID string + if o, ok := srcObj.(*Object); ok { + parentID = o.ParentID + } else { + return fmt.Errorf("srcObj is not Object") + } + if parentID == "" { + parentID = "root" + } + data := base.Json{ + "parentReference": base.Json{ + "id": parentID, + }, + "name": newName, + } + url := d.GetMetaUrl(false, srcObj.GetPath()) + _, err := d.Request(url, http.MethodPatch, func(req *resty.Request) { + req.SetBody(data) + }, nil) + return err +} + +func (d *OnedriveAPP) Copy(ctx context.Context, srcObj, dstDir model.Obj) error { + dst, err := d.GetFile(dstDir.GetPath()) + if err != nil { + return err + } + data := base.Json{ + "parentReference": base.Json{ + "driveId": dst.ParentReference.DriveId, + "id": dst.Id, + }, + "name": srcObj.GetName(), + } + url := d.GetMetaUrl(false, srcObj.GetPath()) + "/copy" + _, err = d.Request(url, http.MethodPost, 
func(req *resty.Request) { + req.SetBody(data) + }, nil) + return err +} + +func (d *OnedriveAPP) Remove(ctx context.Context, obj model.Obj) error { + url := d.GetMetaUrl(false, obj.GetPath()) + _, err := d.Request(url, http.MethodDelete, nil, nil) + return err +} + +func (d *OnedriveAPP) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error { + var err error + if stream.GetSize() <= 4*1024*1024 { + err = d.upSmall(ctx, dstDir, stream) + } else { + err = d.upBig(ctx, dstDir, stream, up) + } + return err +} + +var _ driver.Driver = (*OnedriveAPP)(nil) diff --git a/drivers/onedrive_app/meta.go b/drivers/onedrive_app/meta.go new file mode 100644 index 0000000000000000000000000000000000000000..0499f503bad334597d0b0021cdd1439a3dde36ae --- /dev/null +++ b/drivers/onedrive_app/meta.go @@ -0,0 +1,29 @@ +package onedrive_app + +import ( + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/op" +) + +type Addition struct { + driver.RootPath + Region string `json:"region" type:"select" required:"true" options:"global,cn,us,de" default:"global"` + ClientID string `json:"client_id" required:"true"` + ClientSecret string `json:"client_secret" required:"true"` + TenantID string `json:"tenant_id"` + Email string `json:"email"` + ChunkSize int64 `json:"chunk_size" type:"number" default:"5"` + CustomHost string `json:"custom_host" help:"Custom host for onedrive download link"` +} + +var config = driver.Config{ + Name: "OnedriveAPP", + LocalSort: true, + DefaultRoot: "/", +} + +func init() { + op.RegisterDriver(func() driver.Driver { + return &OnedriveAPP{} + }) +} diff --git a/drivers/onedrive_app/types.go b/drivers/onedrive_app/types.go new file mode 100644 index 0000000000000000000000000000000000000000..7179e4b450ceded97db2415d2fa41dd79a39ce38 --- /dev/null +++ b/drivers/onedrive_app/types.go @@ -0,0 +1,74 @@ +package onedrive_app + +import ( + "time" + + "github.com/alist-org/alist/v3/internal/model" +) + +type Host struct { + Oauth string + Api string +} + +type TokenErr struct { + Error string `json:"error"` + ErrorDescription string `json:"error_description"` +} + +type RespErr struct { + Error struct { + Code string `json:"code"` + Message string `json:"message"` + } `json:"error"` +} + +type File struct { + Id string `json:"id"` + Name string `json:"name"` + Size int64 `json:"size"` + LastModifiedDateTime time.Time `json:"lastModifiedDateTime"` + Url string `json:"@microsoft.graph.downloadUrl"` + File *struct { + MimeType string `json:"mimeType"` + } `json:"file"` + Thumbnails []struct { + Medium struct { + Url string `json:"url"` + } `json:"medium"` + } `json:"thumbnails"` + ParentReference struct { + DriveId string `json:"driveId"` + } `json:"parentReference"` +} + +type Object struct { + model.ObjThumb + ParentID string +} + +func fileToObj(f File, parentID string) *Object { + thumb := "" + if len(f.Thumbnails) > 0 { + thumb = f.Thumbnails[0].Medium.Url + } + return &Object{ + ObjThumb: model.ObjThumb{ + Object: model.Object{ + ID: f.Id, + Name: f.Name, + Size: f.Size, + Modified: f.LastModifiedDateTime, + IsFolder: f.File == nil, + }, + Thumbnail: model.Thumbnail{Thumbnail: thumb}, + //Url: model.Url{Url: f.Url}, + }, + ParentID: parentID, + } +} + +type Files struct { + Value []File `json:"value"` + NextLink string `json:"@odata.nextLink"` +} diff --git a/drivers/onedrive_app/util.go b/drivers/onedrive_app/util.go new file mode 100644 index 
0000000000000000000000000000000000000000..28b34837806e828ccc70c220ab0155bacc80d4eb --- /dev/null +++ b/drivers/onedrive_app/util.go @@ -0,0 +1,200 @@ +package onedrive_app + +import ( + "bytes" + "context" + "errors" + "fmt" + "io" + "net/http" + stdpath "path" + "strconv" + + "github.com/alist-org/alist/v3/drivers/base" + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/errs" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/internal/op" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/go-resty/resty/v2" + jsoniter "github.com/json-iterator/go" + log "github.com/sirupsen/logrus" +) + +var onedriveHostMap = map[string]Host{ + "global": { + Oauth: "https://login.microsoftonline.com", + Api: "https://graph.microsoft.com", + }, + "cn": { + Oauth: "https://login.chinacloudapi.cn", + Api: "https://microsoftgraph.chinacloudapi.cn", + }, + "us": { + Oauth: "https://login.microsoftonline.us", + Api: "https://graph.microsoft.us", + }, + "de": { + Oauth: "https://login.microsoftonline.de", + Api: "https://graph.microsoft.de", + }, +} + +func (d *OnedriveAPP) GetMetaUrl(auth bool, path string) string { + host, _ := onedriveHostMap[d.Region] + path = utils.EncodePath(path, true) + if auth { + return host.Oauth + } + if path == "/" || path == "\\" { + return fmt.Sprintf("%s/v1.0/users/%s/drive/root", host.Api, d.Email) + } + return fmt.Sprintf("%s/v1.0/users/%s/drive/root:%s:", host.Api, d.Email, path) +} + +func (d *OnedriveAPP) accessToken() error { + var err error + for i := 0; i < 3; i++ { + err = d._accessToken() + if err == nil { + break + } + } + return err +} + +func (d *OnedriveAPP) _accessToken() error { + url := d.GetMetaUrl(true, "") + "/" + d.TenantID + "/oauth2/token" + var resp base.TokenResp + var e TokenErr + _, err := base.RestyClient.R().SetResult(&resp).SetError(&e).SetFormData(map[string]string{ + "grant_type": "client_credentials", + "client_id": d.ClientID, + "client_secret": d.ClientSecret, + "resource": onedriveHostMap[d.Region].Api + "/", + "scope": onedriveHostMap[d.Region].Api + "/.default", + }).Post(url) + if err != nil { + return err + } + if e.Error != "" { + return fmt.Errorf("%s", e.ErrorDescription) + } + if resp.AccessToken == "" { + return errs.EmptyToken + } + d.AccessToken = resp.AccessToken + op.MustSaveDriverStorage(d) + return nil +} + +func (d *OnedriveAPP) Request(url string, method string, callback base.ReqCallback, resp interface{}) ([]byte, error) { + req := base.RestyClient.R() + req.SetHeader("Authorization", "Bearer "+d.AccessToken) + if callback != nil { + callback(req) + } + if resp != nil { + req.SetResult(resp) + } + var e RespErr + req.SetError(&e) + res, err := req.Execute(method, url) + if err != nil { + return nil, err + } + if e.Error.Code != "" { + if e.Error.Code == "InvalidAuthenticationToken" { + err = d.accessToken() + if err != nil { + return nil, err + } + return d.Request(url, method, callback, resp) + } + return nil, errors.New(e.Error.Message) + } + return res.Body(), nil +} + +func (d *OnedriveAPP) getFiles(path string) ([]File, error) { + var res []File + nextLink := d.GetMetaUrl(false, path) + "/children?$top=5000&$expand=thumbnails($select=medium)&$select=id,name,size,lastModifiedDateTime,content.downloadUrl,file,parentReference" + for nextLink != "" { + var files Files + _, err := d.Request(nextLink, http.MethodGet, nil, &files) + if err != nil { + return nil, err + } + res = append(res, files.Value...) 
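+		// Microsoft Graph pages large listings; keep following
+		// @odata.nextLink until no further page is returned.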
+ nextLink = files.NextLink + } + return res, nil +} + +func (d *OnedriveAPP) GetFile(path string) (*File, error) { + var file File + u := d.GetMetaUrl(false, path) + _, err := d.Request(u, http.MethodGet, nil, &file) + return &file, err +} + +func (d *OnedriveAPP) upSmall(ctx context.Context, dstDir model.Obj, stream model.FileStreamer) error { + url := d.GetMetaUrl(false, stdpath.Join(dstDir.GetPath(), stream.GetName())) + "/content" + data, err := io.ReadAll(stream) + if err != nil { + return err + } + _, err = d.Request(url, http.MethodPut, func(req *resty.Request) { + req.SetBody(data).SetContext(ctx) + }, nil) + return err +} + +func (d *OnedriveAPP) upBig(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error { + url := d.GetMetaUrl(false, stdpath.Join(dstDir.GetPath(), stream.GetName())) + "/createUploadSession" + res, err := d.Request(url, http.MethodPost, nil, nil) + if err != nil { + return err + } + uploadUrl := jsoniter.Get(res, "uploadUrl").ToString() + var finish int64 = 0 + DEFAULT := d.ChunkSize * 1024 * 1024 + for finish < stream.GetSize() { + if utils.IsCanceled(ctx) { + return ctx.Err() + } + log.Debugf("upload: %d", finish) + var byteSize int64 = DEFAULT + left := stream.GetSize() - finish + if left < DEFAULT { + byteSize = left + } + byteData := make([]byte, byteSize) + n, err := io.ReadFull(stream, byteData) + log.Debug(err, n) + if err != nil { + return err + } + req, err := http.NewRequest("PUT", uploadUrl, bytes.NewBuffer(byteData)) + if err != nil { + return err + } + req = req.WithContext(ctx) + req.Header.Set("Content-Length", strconv.Itoa(int(byteSize))) + req.Header.Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", finish, finish+byteSize-1, stream.GetSize())) + finish += byteSize + res, err := base.HttpClient.Do(req) + if err != nil { + return err + } + // https://learn.microsoft.com/zh-cn/onedrive/developer/rest-api/api/driveitem_createuploadsession + if res.StatusCode != 201 && res.StatusCode != 202 && res.StatusCode != 200 { + data, _ := io.ReadAll(res.Body) + res.Body.Close() + return errors.New(string(data)) + } + res.Body.Close() + up(float64(finish) * 100 / float64(stream.GetSize())) + } + return nil +} diff --git a/drivers/onedrive_sharelink/driver.go b/drivers/onedrive_sharelink/driver.go new file mode 100644 index 0000000000000000000000000000000000000000..1282409e4b718f09b1f8a2e34024b069f2cbcf66 --- /dev/null +++ b/drivers/onedrive_sharelink/driver.go @@ -0,0 +1,131 @@ +package onedrive_sharelink + +import ( + "context" + "strings" + "time" + + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/errs" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/pkg/cron" + "github.com/alist-org/alist/v3/pkg/utils" + log "github.com/sirupsen/logrus" +) + +type OnedriveSharelink struct { + model.Storage + cron *cron.Cron + Addition +} + +func (d *OnedriveSharelink) Config() driver.Config { + return config +} + +func (d *OnedriveSharelink) GetAddition() driver.Additional { + return &d.Addition +} + +func (d *OnedriveSharelink) Init(ctx context.Context) error { + // Initialize error variable + var err error + + // If there is "-my" in the URL, it is NOT a SharePoint link + d.IsSharepoint = !strings.Contains(d.ShareLinkURL, "-my") + + // Initialize cron job to run every hour + d.cron = cron.NewCron(time.Hour * 1) + d.cron.Do(func() { + var err error + d.Headers, err = d.getHeaders() + if err != nil { + log.Errorf("%+v", err) + } + }) + + // Get initial 
headers + d.Headers, err = d.getHeaders() + if err != nil { + return err + } + + return nil +} + +func (d *OnedriveSharelink) Drop(ctx context.Context) error { + return nil +} + +func (d *OnedriveSharelink) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) { + path := dir.GetPath() + files, err := d.getFiles(path) + if err != nil { + return nil, err + } + + // Convert the slice of files to the required model.Obj format + return utils.SliceConvert(files, func(src Item) (model.Obj, error) { + return fileToObj(src), nil + }) +} + +func (d *OnedriveSharelink) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) { + // Get the unique ID of the file + uniqueId := file.GetID() + // Cut the first char and the last char + uniqueId = uniqueId[1 : len(uniqueId)-1] + url := d.downloadLinkPrefix + uniqueId + header := d.Headers + + // If the headers are older than 30 minutes, get new headers + if d.HeaderTime < time.Now().Unix()-1800 { + var err error + log.Debug("headers are older than 30 minutes, get new headers") + header, err = d.getHeaders() + if err != nil { + return nil, err + } + } + + return &model.Link{ + URL: url, + Header: header, + }, nil +} + +func (d *OnedriveSharelink) MakeDir(ctx context.Context, parentDir model.Obj, dirName string) error { + // TODO create folder, optional + return errs.NotImplement +} + +func (d *OnedriveSharelink) Move(ctx context.Context, srcObj, dstDir model.Obj) error { + // TODO move obj, optional + return errs.NotImplement +} + +func (d *OnedriveSharelink) Rename(ctx context.Context, srcObj model.Obj, newName string) error { + // TODO rename obj, optional + return errs.NotImplement +} + +func (d *OnedriveSharelink) Copy(ctx context.Context, srcObj, dstDir model.Obj) error { + // TODO copy obj, optional + return errs.NotImplement +} + +func (d *OnedriveSharelink) Remove(ctx context.Context, obj model.Obj) error { + // TODO remove obj, optional + return errs.NotImplement +} + +func (d *OnedriveSharelink) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error { + // TODO upload file, optional + return errs.NotImplement +} + +//func (d *OnedriveSharelink) Other(ctx context.Context, args model.OtherArgs) (interface{}, error) { +// return nil, errs.NotSupport +//} + +var _ driver.Driver = (*OnedriveSharelink)(nil) diff --git a/drivers/onedrive_sharelink/meta.go b/drivers/onedrive_sharelink/meta.go new file mode 100644 index 0000000000000000000000000000000000000000..6f1ccfc459178c98f24c1dc6f7488ca639c4e797 --- /dev/null +++ b/drivers/onedrive_sharelink/meta.go @@ -0,0 +1,32 @@ +package onedrive_sharelink + +import ( + "net/http" + + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/op" +) + +type Addition struct { + driver.RootPath + ShareLinkURL string `json:"url" required:"true"` + ShareLinkPassword string `json:"password"` + IsSharepoint bool + downloadLinkPrefix string + Headers http.Header + HeaderTime int64 +} + +var config = driver.Config{ + Name: "Onedrive Sharelink", + OnlyProxy: true, + NoUpload: true, + DefaultRoot: "/", + CheckStatus: false, +} + +func init() { + op.RegisterDriver(func() driver.Driver { + return &OnedriveSharelink{} + }) +} diff --git a/drivers/onedrive_sharelink/types.go b/drivers/onedrive_sharelink/types.go new file mode 100644 index 0000000000000000000000000000000000000000..2433425026396caaebd1e3d1eff673174065dcfb --- /dev/null +++ b/drivers/onedrive_sharelink/types.go @@ -0,0 +1,77 @@ 
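+// The types below mirror the JSON that SharePoint's renderListDataAsStream
+// endpoint returns for a share-link listing.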
+package onedrive_sharelink + +import ( + "strconv" + "time" + + "github.com/alist-org/alist/v3/internal/model" +) + +// FolderResp represents the structure of the folder response from the OneDrive API. +type FolderResp struct { + // Data holds the nested structure of the response. + Data struct { + Legacy struct { + RenderListData struct { + ListData struct { + Items []Item `json:"Row"` // Items contains the list of items in the folder. + } `json:"ListData"` + } `json:"renderListDataAsStream"` + } `json:"legacy"` + } `json:"data"` +} + +// Item represents an individual item in the folder. +type Item struct { + ObjType string `json:"FSObjType"` // ObjType indicates if the item is a file or folder. + Name string `json:"FileLeafRef"` // Name is the name of the item. + ModifiedTime time.Time `json:"Modified."` // ModifiedTime is the last modified time of the item. + Size string `json:"File_x0020_Size"` // Size is the size of the item in string format. + Id string `json:"UniqueId"` // Id is the unique identifier of the item. +} + +// fileToObj converts an Item to an ObjThumb. +func fileToObj(f Item) *model.ObjThumb { + // Convert Size from string to int64. + size, _ := strconv.ParseInt(f.Size, 10, 64) + // Convert ObjType from string to int. + objtype, _ := strconv.Atoi(f.ObjType) + + // Create a new ObjThumb with the converted values. + file := &model.ObjThumb{ + Object: model.Object{ + Name: f.Name, + Modified: f.ModifiedTime, + Size: size, + IsFolder: objtype == 1, // Check if the item is a folder. + ID: f.Id, + }, + Thumbnail: model.Thumbnail{}, + } + return file +} + +// GraphQLNEWRequest represents the structure of a new GraphQL request. +type GraphQLNEWRequest struct { + ListData struct { + NextHref string `json:"NextHref"` // NextHref is the link to the next set of data. + Row []Item `json:"Row"` // Row contains the list of items. + } `json:"ListData"` +} + +// GraphQLRequest represents the structure of a GraphQL request. +type GraphQLRequest struct { + Data struct { + Legacy struct { + RenderListDataAsStream struct { + ListData struct { + NextHref string `json:"NextHref"` // NextHref is the link to the next set of data. + Row []Item `json:"Row"` // Row contains the list of items. + } `json:"ListData"` + ViewMetadata struct { + ListViewXml string `json:"ListViewXml"` // ListViewXml contains the XML of the list view. 
+				} `json:"ViewMetadata"`
+			} `json:"renderListDataAsStream"`
+		} `json:"legacy"`
+	} `json:"data"`
+}
diff --git a/drivers/onedrive_sharelink/util.go b/drivers/onedrive_sharelink/util.go
new file mode 100644
index 0000000000000000000000000000000000000000..4a1c92b5af81f7b04dbca4eeeb3d06ebb47e5f2a
--- /dev/null
+++ b/drivers/onedrive_sharelink/util.go
@@ -0,0 +1,363 @@
+package onedrive_sharelink
+
+import (
+	"crypto/tls"
+	"encoding/json"
+	"fmt"
+	"io"
+	"net/http"
+	"net/url"
+	"regexp"
+	"strings"
+	"time"
+
+	"github.com/alist-org/alist/v3/drivers/base"
+	"github.com/alist-org/alist/v3/internal/conf"
+	log "github.com/sirupsen/logrus"
+	"golang.org/x/net/html"
+)
+
+// NewNoRedirectClient creates an HTTP client that doesn't follow redirects
+func NewNoRedirectClient() *http.Client {
+	return &http.Client{
+		Timeout: time.Hour * 48,
+		Transport: &http.Transport{
+			Proxy:           http.ProxyFromEnvironment,
+			TLSClientConfig: &tls.Config{InsecureSkipVerify: conf.Conf.TlsInsecureSkipVerify},
+		},
+		// Prevent following redirects
+		CheckRedirect: func(req *http.Request, via []*http.Request) error {
+			return http.ErrUseLastResponse
+		},
+	}
+}
+
+// getCookiesWithPassword fetches cookies required for authenticated access using the provided password
+func getCookiesWithPassword(link, password string) (string, error) {
+	// Send GET request
+	resp, err := http.Get(link)
+	if err != nil {
+		return "", err
+	}
+	defer resp.Body.Close()
+
+	// Parse the HTML response
+	doc, err := html.Parse(resp.Body)
+	if err != nil {
+		return "", err
+	}
+
+	// Initialize variables to store form data
+	var viewstate, eventvalidation, postAction string
+
+	// Recursive function to find input fields by their IDs
+	var findInputFields func(*html.Node)
+	findInputFields = func(n *html.Node) {
+		if n.Type == html.ElementNode && n.Data == "input" {
+			for _, attr := range n.Attr {
+				if attr.Key == "id" {
+					switch attr.Val {
+					case "__VIEWSTATE":
+						viewstate = getAttrValue(n, "value")
+					case "__EVENTVALIDATION":
+						eventvalidation = getAttrValue(n, "value")
+					}
+				}
+			}
+		}
+		if n.Type == html.ElementNode && n.Data == "form" {
+			for _, attr := range n.Attr {
+				if attr.Key == "id" && attr.Val == "inputForm" {
+					postAction = getAttrValue(n, "action")
+				}
+			}
+		}
+		for c := n.FirstChild; c != nil; c = c.NextSibling {
+			findInputFields(c)
+		}
+	}
+	findInputFields(doc)
+
+	// Prepare the new URL for the POST request
+	linkParts, err := url.Parse(link)
+	if err != nil {
+		return "", err
+	}
+
+	newURL := fmt.Sprintf("%s://%s%s", linkParts.Scheme, linkParts.Host, postAction)
+
+	// Prepare the request body
+	data := url.Values{
+		"txtPassword":          []string{password},
+		"__EVENTVALIDATION":    []string{eventvalidation},
+		"__VIEWSTATE":          []string{viewstate},
+		"__VIEWSTATEENCRYPTED": []string{""},
+	}
+
+	client := &http.Client{
+		CheckRedirect: func(req *http.Request, via []*http.Request) error {
+			return http.ErrUseLastResponse
+		},
+	}
+	// Send the POST request, preventing redirects
+	resp, err = client.PostForm(newURL, data)
+	if err != nil {
+		return "", err
+	}
+	defer resp.Body.Close() // close the second response body as well to avoid leaking the connection
+
+	// Extract the desired cookie value
+	cookie := resp.Cookies()
+	var fedAuthCookie string
+	for _, c := range cookie {
+		if c.Name == "FedAuth" {
+			fedAuthCookie = c.Value
+			break
+		}
+	}
+	if fedAuthCookie == "" {
+		return "", fmt.Errorf("wrong password")
+	}
+	return fmt.Sprintf("FedAuth=%s;", fedAuthCookie), nil
+}
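+
+// Note: the FedAuth cookie extracted above is what authenticates all later
+// requests against the share; getHeaders attaches either this cookie or the
+// Set-Cookie value from the anonymous redirect to every request.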
+
+// getAttrValue retrieves the value of the specified attribute from an HTML node
+func getAttrValue(n *html.Node, key string) string {
+	for _, attr := range n.Attr {
+		if attr.Key == key {
+			return attr.Val
+		}
+	}
+	return ""
+}
+
+// getHeaders constructs and returns the necessary HTTP headers for accessing the OneDrive share link
+func (d *OnedriveSharelink) getHeaders() (http.Header, error) {
+	header := http.Header{}
+	header.Set("User-Agent", base.UserAgent)
+	header.Set("accept-language", "zh-CN,zh;q=0.9,en;q=0.8,en-GB;q=0.7,en-US;q=0.6")
+
+	// Save current timestamp to d.HeaderTime
+	d.HeaderTime = time.Now().Unix()
+
+	if d.ShareLinkPassword == "" {
+		// Create a no-redirect client
+		clientNoDirect := NewNoRedirectClient()
+		req, err := http.NewRequest("GET", d.ShareLinkURL, nil)
+		if err != nil {
+			return nil, err
+		}
+		// Set headers for the request
+		req.Header = header
+		answerNoRedirect, err := clientNoDirect.Do(req)
+		if err != nil {
+			return nil, err
+		}
+		redirectUrl := answerNoRedirect.Header.Get("Location")
+		log.Debugln("redirectUrl:", redirectUrl)
+		if redirectUrl == "" {
+			return nil, fmt.Errorf("password-protected link, please provide the password")
+		}
+		header.Set("Cookie", answerNoRedirect.Header.Get("Set-Cookie"))
+		header.Set("Referer", redirectUrl)
+
+		// Extract the host part of the redirect URL and set it as the authority
+		u, err := url.Parse(redirectUrl)
+		if err != nil {
+			return nil, err
+		}
+		header.Set("authority", u.Host)
+		return header, nil
+	} else {
+		cookie, err := getCookiesWithPassword(d.ShareLinkURL, d.ShareLinkPassword)
+		if err != nil {
+			return nil, err
+		}
+		header.Set("Cookie", cookie)
+		header.Set("Referer", d.ShareLinkURL)
+		header.Set("authority", strings.Split(strings.Split(d.ShareLinkURL, "//")[1], "/")[0])
+		return header, nil
+	}
+}
+
+// getFiles retrieves the files from the OneDrive share link at the specified path
+func (d *OnedriveSharelink) getFiles(path string) ([]Item, error) {
+	clientNoDirect := NewNoRedirectClient()
+	req, err := http.NewRequest("GET", d.ShareLinkURL, nil)
+	if err != nil {
+		return nil, err
+	}
+	header := req.Header
+	redirectUrl := ""
+	if d.ShareLinkPassword == "" {
+		header.Set("User-Agent", base.UserAgent)
+		header.Set("accept-language", "zh-CN,zh;q=0.9,en;q=0.8,en-GB;q=0.7,en-US;q=0.6")
+		req.Header = header
+		answerNoRedirect, err := clientNoDirect.Do(req)
+		if err != nil {
+			return nil, err
+		}
+		redirectUrl = answerNoRedirect.Header.Get("Location")
+	} else {
+		header = d.Headers
+		req.Header = header
+		answerNoRedirect, err := clientNoDirect.Do(req)
+		if err != nil {
+			return nil, err
+		}
+		redirectUrl = answerNoRedirect.Header.Get("Location")
+	}
+	redirectSplitURL := strings.Split(redirectUrl, "/")
+	req.Header = d.Headers
+	downloadLinkPrefix := ""
+	rootFolderPre := ""
+
+	// Determine the appropriate URL and root folder based on whether the link is SharePoint
+	if d.IsSharepoint {
+		// update req url
+		req.URL, err = url.Parse(redirectUrl)
+		if err != nil {
+			return nil, err
+		}
+		// Get redirectUrl
+		answer, err := clientNoDirect.Do(req)
+		if err != nil {
+			d.Headers, err = d.getHeaders()
+			if err != nil {
+				return nil, err
+			}
+			return d.getFiles(path)
+		}
+		defer answer.Body.Close()
+		re := regexp.MustCompile(`templateUrl":"(.*?)"`)
+		body, err := io.ReadAll(answer.Body)
+		if err != nil {
+			return nil, err
+		}
+		template := re.FindString(string(body))
+		template = template[strings.Index(template, "templateUrl\":\"")+len("templateUrl\":\""):]
+		template = template[:strings.Index(template, "?id=")]
+		template = template[:strings.LastIndex(template, "/")]
+		downloadLinkPrefix = template + "/download.aspx?UniqueId="
+		params, err :=
url.ParseQuery(redirectUrl[strings.Index(redirectUrl, "?")+1:]) + if err != nil { + return nil, err + } + rootFolderPre = params.Get("id") + } else { + redirectUrlCut := redirectUrl[:strings.LastIndex(redirectUrl, "/")] + downloadLinkPrefix = redirectUrlCut + "/download.aspx?UniqueId=" + params, err := url.ParseQuery(redirectUrl[strings.Index(redirectUrl, "?")+1:]) + if err != nil { + return nil, err + } + rootFolderPre = params.Get("id") + } + d.downloadLinkPrefix = downloadLinkPrefix + rootFolder, err := url.QueryUnescape(rootFolderPre) + if err != nil { + return nil, err + } + log.Debugln("rootFolder:", rootFolder) + // Extract the relative path up to and including "Documents" + relativePath := strings.Split(rootFolder, "Documents")[0] + "Documents" + + // URL encode the relative path + relativeUrl := url.QueryEscape(relativePath) + // Replace underscores and hyphens in the encoded relative path + relativeUrl = strings.Replace(relativeUrl, "_", "%5F", -1) + relativeUrl = strings.Replace(relativeUrl, "-", "%2D", -1) + + // If the path is not the root, append the path to the root folder + if path != "/" { + rootFolder = rootFolder + path + } + + // URL encode the full root folder path + rootFolderUrl := url.QueryEscape(rootFolder) + // Replace underscores and hyphens in the encoded root folder URL + rootFolderUrl = strings.Replace(rootFolderUrl, "_", "%5F", -1) + rootFolderUrl = strings.Replace(rootFolderUrl, "-", "%2D", -1) + + log.Debugln("relativePath:", relativePath, "relativeUrl:", relativeUrl, "rootFolder:", rootFolder, "rootFolderUrl:", rootFolderUrl) + + // Construct the GraphQL query with the encoded paths + graphqlVar := fmt.Sprintf(`{"query":"query (\n $listServerRelativeUrl: String!,$renderListDataAsStreamParameters: RenderListDataAsStreamParameters!,$renderListDataAsStreamQueryString: String!\n )\n {\n \n legacy {\n \n renderListDataAsStream(\n listServerRelativeUrl: $listServerRelativeUrl,\n parameters: $renderListDataAsStreamParameters,\n queryString: $renderListDataAsStreamQueryString\n )\n }\n \n \n perf {\n executionTime\n overheadTime\n parsingTime\n queryCount\n validationTime\n resolvers {\n name\n queryCount\n resolveTime\n waitTime\n }\n }\n }","variables":{"listServerRelativeUrl":"%s","renderListDataAsStreamParameters":{"renderOptions":5707527,"allowMultipleValueFilterForTaxonomyFields":true,"addRequiredFields":true,"folderServerRelativeUrl":"%s"},"renderListDataAsStreamQueryString":"@a1=\'%s\'&RootFolder=%s&TryNewExperienceSingle=TRUE"}}`, relativePath, rootFolder, relativeUrl, rootFolderUrl) + tempHeader := make(http.Header) + for k, v := range d.Headers { + tempHeader[k] = v + } + tempHeader["Content-Type"] = []string{"application/json;odata=verbose"} + + client := &http.Client{} + postUrl := strings.Join(redirectSplitURL[:len(redirectSplitURL)-3], "/") + "/_api/v2.1/graphql" + req, err = http.NewRequest("POST", postUrl, strings.NewReader(graphqlVar)) + if err != nil { + return nil, err + } + req.Header = tempHeader + + resp, err := client.Do(req) + if err != nil { + d.Headers, err = d.getHeaders() + if err != nil { + return nil, err + } + return d.getFiles(path) + } + defer resp.Body.Close() + var graphqlReq GraphQLRequest + json.NewDecoder(resp.Body).Decode(&graphqlReq) + log.Debugln("graphqlReq:", graphqlReq) + filesData := graphqlReq.Data.Legacy.RenderListDataAsStream.ListData.Row + if graphqlReq.Data.Legacy.RenderListDataAsStream.ListData.NextHref != "" { + nextHref := graphqlReq.Data.Legacy.RenderListDataAsStream.ListData.NextHref + 
"&@a1=REPLACEME&TryNewExperienceSingle=TRUE" + nextHref = strings.Replace(nextHref, "REPLACEME", "%27"+relativeUrl+"%27", -1) + log.Debugln("nextHref:", nextHref) + filesData = append(filesData, graphqlReq.Data.Legacy.RenderListDataAsStream.ListData.Row...) + + listViewXml := graphqlReq.Data.Legacy.RenderListDataAsStream.ViewMetadata.ListViewXml + log.Debugln("listViewXml:", listViewXml) + renderListDataAsStreamVar := `{"parameters":{"__metadata":{"type":"SP.RenderListDataParameters"},"RenderOptions":1216519,"ViewXml":"REPLACEME","AllowMultipleValueFilterForTaxonomyFields":true,"AddRequiredFields":true}}` + listViewXml = strings.Replace(listViewXml, `"`, `\"`, -1) + renderListDataAsStreamVar = strings.Replace(renderListDataAsStreamVar, "REPLACEME", listViewXml, -1) + + graphqlReqNEW := GraphQLNEWRequest{} + postUrl = strings.Join(redirectSplitURL[:len(redirectSplitURL)-3], "/") + "/_api/web/GetListUsingPath(DecodedUrl=@a1)/RenderListDataAsStream" + nextHref + req, _ = http.NewRequest("POST", postUrl, strings.NewReader(renderListDataAsStreamVar)) + req.Header = tempHeader + + resp, err := client.Do(req) + if err != nil { + d.Headers, err = d.getHeaders() + if err != nil { + return nil, err + } + return d.getFiles(path) + } + defer resp.Body.Close() + json.NewDecoder(resp.Body).Decode(&graphqlReqNEW) + for graphqlReqNEW.ListData.NextHref != "" { + graphqlReqNEW = GraphQLNEWRequest{} + postUrl = strings.Join(redirectSplitURL[:len(redirectSplitURL)-3], "/") + "/_api/web/GetListUsingPath(DecodedUrl=@a1)/RenderListDataAsStream" + nextHref + req, _ = http.NewRequest("POST", postUrl, strings.NewReader(renderListDataAsStreamVar)) + req.Header = tempHeader + resp, err := client.Do(req) + if err != nil { + d.Headers, err = d.getHeaders() + if err != nil { + return nil, err + } + return d.getFiles(path) + } + defer resp.Body.Close() + json.NewDecoder(resp.Body).Decode(&graphqlReqNEW) + nextHref = graphqlReqNEW.ListData.NextHref + "&@a1=REPLACEME&TryNewExperienceSingle=TRUE" + nextHref = strings.Replace(nextHref, "REPLACEME", "%27"+relativeUrl+"%27", -1) + filesData = append(filesData, graphqlReqNEW.ListData.Row...) + } + filesData = append(filesData, graphqlReqNEW.ListData.Row...) + } else { + filesData = append(filesData, graphqlReq.Data.Legacy.RenderListDataAsStream.ListData.Row...) 
+	}
+	return filesData, nil
+}
diff --git a/drivers/pikpak/driver.go b/drivers/pikpak/driver.go
new file mode 100644
index 0000000000000000000000000000000000000000..dd54b70ddb32fa3f8888ecc3fadef1fcb25d0b7d
--- /dev/null
+++ b/drivers/pikpak/driver.go
@@ -0,0 +1,364 @@
+package pikpak
+
+import (
+	"context"
+	"encoding/json"
+	"fmt"
+	"net/http"
+	"strconv"
+	"strings"
+
+	"github.com/alist-org/alist/v3/drivers/base"
+	"github.com/alist-org/alist/v3/internal/driver"
+	"github.com/alist-org/alist/v3/internal/model"
+	"github.com/alist-org/alist/v3/internal/op"
+	"github.com/alist-org/alist/v3/pkg/utils"
+	hash_extend "github.com/alist-org/alist/v3/pkg/utils/hash"
+	"github.com/go-resty/resty/v2"
+	log "github.com/sirupsen/logrus"
+)
+
+type PikPak struct {
+	model.Storage
+	Addition
+	*Common
+	RefreshToken string
+	AccessToken  string
+}
+
+func (d *PikPak) Config() driver.Config {
+	return config
+}
+
+func (d *PikPak) GetAddition() driver.Additional {
+	return &d.Addition
+}
+
+func (d *PikPak) Init(ctx context.Context) (err error) {
+
+	if d.Common == nil {
+		d.Common = &Common{
+			client:       base.NewRestyClient(),
+			CaptchaToken: "",
+			UserID:       "",
+			DeviceID:     utils.GetMD5EncodeStr(d.Username + d.Password),
+			UserAgent:    "",
+			RefreshCTokenCk: func(token string) {
+				d.Common.CaptchaToken = token
+				op.MustSaveDriverStorage(d)
+			},
+		}
+	}
+
+	if d.Platform == "android" {
+		d.ClientID = AndroidClientID
+		d.ClientSecret = AndroidClientSecret
+		d.ClientVersion = AndroidClientVersion
+		d.PackageName = AndroidPackageName
+		d.Algorithms = AndroidAlgorithms
+		d.UserAgent = BuildCustomUserAgent(utils.GetMD5EncodeStr(d.Username+d.Password), AndroidClientID, AndroidPackageName, AndroidSdkVersion, AndroidClientVersion, AndroidPackageName, "")
+	} else if d.Platform == "web" {
+		d.ClientID = WebClientID
+		d.ClientSecret = WebClientSecret
+		d.ClientVersion = WebClientVersion
+		d.PackageName = WebPackageName
+		d.Algorithms = WebAlgorithms
+		d.UserAgent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/117.0.0.0 Safari/537.36"
+	} else if d.Platform == "pc" {
+		d.ClientID = PCClientID
+		d.ClientSecret = PCClientSecret
+		d.ClientVersion = PCClientVersion
+		d.PackageName = PCPackageName
+		d.Algorithms = PCAlgorithms
+		d.UserAgent = "MainWindow Mozilla/5.0 (Windows NT 10.0; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) PikPak/2.6.11.4955 Chrome/100.0.4896.160 Electron/18.3.15 Safari/537.36"
+	}
+
+	if d.Addition.CaptchaToken != "" && d.Addition.RefreshToken == "" {
+		d.SetCaptchaToken(d.Addition.CaptchaToken)
+	}
+
+	if d.Addition.DeviceID != "" {
+		d.SetDeviceID(d.Addition.DeviceID)
+	} else {
+		d.Addition.DeviceID = d.Common.DeviceID
+		op.MustSaveDriverStorage(d)
+	}
+	// If a RefreshToken is already configured, use it to obtain an AccessToken directly
+	if d.Addition.RefreshToken != "" {
+		if err = d.refreshToken(d.Addition.RefreshToken); err != nil {
+			return err
+		}
+	} else {
+		// No RefreshToken configured: try to log in and obtain one
+		if err = d.login(); err != nil {
+			return err
+		}
+	}
+
+	// Obtain a CaptchaToken
+	err = d.RefreshCaptchaTokenAtLogin(GetAction(http.MethodGet, "https://api-drive.mypikpak.net/drive/v1/files"), d.Common.GetUserID())
+	if err != nil {
+		return err
+	}
+
+	// Update the UserAgent (the Android UA embeds the user ID, which is only known after login)
+	if d.Platform == "android" {
+		d.Common.UserAgent = BuildCustomUserAgent(utils.GetMD5EncodeStr(d.Username+d.Password), AndroidClientID, AndroidPackageName, AndroidSdkVersion, AndroidClientVersion, AndroidPackageName, d.Common.UserID)
+	}
+
+	// Persist the now-valid RefreshToken
+	d.Addition.RefreshToken = d.RefreshToken
+
op.MustSaveDriverStorage(d) + + return nil +} + +func (d *PikPak) Drop(ctx context.Context) error { + return nil +} + +func (d *PikPak) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) { + files, err := d.getFiles(dir.GetID()) + if err != nil { + return nil, err + } + return utils.SliceConvert(files, func(src File) (model.Obj, error) { + return fileToObj(src), nil + }) +} + +func (d *PikPak) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) { + var resp File + var url string + queryParams := map[string]string{ + "_magic": "2021", + "usage": "FETCH", + "thumbnail_size": "SIZE_LARGE", + } + if !d.DisableMediaLink { + queryParams["usage"] = "CACHE" + } + _, err := d.request(fmt.Sprintf("https://api-drive.mypikpak.net/drive/v1/files/%s", file.GetID()), + http.MethodGet, func(req *resty.Request) { + req.SetQueryParams(queryParams) + }, &resp) + if err != nil { + return nil, err + } + url = resp.WebContentLink + + if !d.DisableMediaLink && len(resp.Medias) > 0 && resp.Medias[0].Link.Url != "" { + log.Debugln("use media link") + url = resp.Medias[0].Link.Url + } + + return &model.Link{ + URL: url, + }, nil +} + +func (d *PikPak) MakeDir(ctx context.Context, parentDir model.Obj, dirName string) error { + _, err := d.request("https://api-drive.mypikpak.net/drive/v1/files", http.MethodPost, func(req *resty.Request) { + req.SetBody(base.Json{ + "kind": "drive#folder", + "parent_id": parentDir.GetID(), + "name": dirName, + }) + }, nil) + return err +} + +func (d *PikPak) Move(ctx context.Context, srcObj, dstDir model.Obj) error { + _, err := d.request("https://api-drive.mypikpak.net/drive/v1/files:batchMove", http.MethodPost, func(req *resty.Request) { + req.SetBody(base.Json{ + "ids": []string{srcObj.GetID()}, + "to": base.Json{ + "parent_id": dstDir.GetID(), + }, + }) + }, nil) + return err +} + +func (d *PikPak) Rename(ctx context.Context, srcObj model.Obj, newName string) error { + _, err := d.request("https://api-drive.mypikpak.net/drive/v1/files/"+srcObj.GetID(), http.MethodPatch, func(req *resty.Request) { + req.SetBody(base.Json{ + "name": newName, + }) + }, nil) + return err +} + +func (d *PikPak) Copy(ctx context.Context, srcObj, dstDir model.Obj) error { + _, err := d.request("https://api-drive.mypikpak.net/drive/v1/files:batchCopy", http.MethodPost, func(req *resty.Request) { + req.SetBody(base.Json{ + "ids": []string{srcObj.GetID()}, + "to": base.Json{ + "parent_id": dstDir.GetID(), + }, + }) + }, nil) + return err +} + +func (d *PikPak) Remove(ctx context.Context, obj model.Obj) error { + //https://api-drive.mypikpak.com/drive/v1/files:batchTrash + _, err := d.request("https://api-drive.mypikpak.net/drive/v1/files:batchDelete", http.MethodPost, func(req *resty.Request) { + req.SetBody(base.Json{ + "ids": []string{obj.GetID()}, + }) + }, nil) + return err +} + +func (d *PikPak) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error { + hi := stream.GetHash() + sha1Str := hi.GetHash(hash_extend.GCID) + if len(sha1Str) < hash_extend.GCID.Width { + tFile, err := stream.CacheFullInTempFile() + if err != nil { + return err + } + + sha1Str, err = utils.HashFile(hash_extend.GCID, tFile, stream.GetSize()) + if err != nil { + return err + } + } + + var resp UploadTaskData + res, err := d.request("https://api-drive.mypikpak.net/drive/v1/files", http.MethodPost, func(req *resty.Request) { + req.SetBody(base.Json{ + "kind": "drive#file", + "name": stream.GetName(), + "size": 
stream.GetSize(),
+			"hash":        strings.ToUpper(sha1Str),
+			"upload_type": "UPLOAD_TYPE_RESUMABLE",
+			"objProvider": base.Json{"provider": "UPLOAD_TYPE_UNKNOWN"},
+			"parent_id":   dstDir.GetID(),
+			"folder_type": "NORMAL",
+		})
+	}, &resp)
+	if err != nil {
+		return err
+	}
+
+	// Instant upload succeeded: the server already has this content, nothing to transfer
+	if resp.Resumable == nil {
+		log.Debugln(string(res))
+		return nil
+	}
+
+	params := resp.Resumable.Params
+	//endpoint := strings.Join(strings.Split(params.Endpoint, ".")[1:], ".")
+	// Web uploads return the endpoint `mypikpak.net`, while Android uploads return one like `vip-lixian-07.mypikpak.net`
+	if d.Addition.Platform == "android" {
+		params.Endpoint = "mypikpak.net"
+	}
+
+	if stream.GetSize() <= 10*utils.MB { // files up to 10MB are uploaded in a single request
+		return d.UploadByOSS(&params, stream, up)
+	}
+	// Multipart upload
+	return d.UploadByMultipart(&params, stream.GetSize(), stream, up)
+}
+
+// Offline download of a file by URL
+func (d *PikPak) Offline(ctx context.Context, args model.OtherArgs) (interface{}, error) {
+	requestBody := base.Json{
+		"kind":        "drive#file",
+		"name":        "",
+		"upload_type": "UPLOAD_TYPE_URL",
+		"url": base.Json{
+			"url": args.Data,
+		},
+		"parent_id":   args.Obj.GetID(),
+		"folder_type": "",
+	}
+
+	_, err := d.requestWithCaptchaToken("https://api-drive.mypikpak.com/drive/v1/files",
+		http.MethodPost, func(r *resty.Request) {
+			r.SetContext(ctx)
+			r.SetBody(requestBody)
+		}, nil)
+	if err != nil {
+		return nil, err
+	}
+	return "ok", nil
+
+	// var resp OfflineDownloadResp
+	// _, err := d.request("https://api-drive.mypikpak.net/drive/v1/files", http.MethodPost, func(req *resty.Request) {
+	// 	req.SetBody(requestBody)
+	// }, &resp)
+
+	// if err != nil {
+	// 	return nil, err
+	// }
+
+	//return &resp.Task, err
+}
+
+/*
+OfflineList fetches the offline download task list.
+Possible phase values:
+PHASE_TYPE_RUNNING, PHASE_TYPE_ERROR, PHASE_TYPE_COMPLETE, PHASE_TYPE_PENDING
+*/
+func (d *PikPak) OfflineList(ctx context.Context, nextPageToken string, phase []string) ([]OfflineTask, error) {
+	res := make([]OfflineTask, 0)
+	url := "https://api-drive.mypikpak.net/drive/v1/tasks"
+
+	if len(phase) == 0 {
+		phase = []string{"PHASE_TYPE_RUNNING", "PHASE_TYPE_ERROR", "PHASE_TYPE_COMPLETE", "PHASE_TYPE_PENDING"}
+	}
+	params := map[string]string{
+		"type":           "offline",
+		"thumbnail_size": "SIZE_SMALL",
+		"limit":          "10000",
+		"page_token":     nextPageToken,
+		"with":           "reference_resource",
+	}
+
+	// Build the phase filter
+	if len(phase) > 0 {
+		filters := base.Json{
+			"phase": map[string]string{
+				"in": strings.Join(phase, ","),
+			},
+		}
+		filtersJSON, err := json.Marshal(filters)
+		if err != nil {
+			return nil, fmt.Errorf("failed to marshal filters: %w", err)
+		}
+		params["filters"] = string(filtersJSON)
+	}
+
+	var resp OfflineListResp
+	_, err := d.request(url, http.MethodGet, func(req *resty.Request) {
+		req.SetContext(ctx).
+			SetQueryParams(params)
+	}, &resp)
+
+	if err != nil {
+		return nil, fmt.Errorf("failed to get offline list: %w", err)
+	}
+	res = append(res, resp.Tasks...)
+	return res, nil
+}
+
+func (d *PikPak) DeleteOfflineTasks(ctx context.Context, taskIDs []string, deleteFiles bool) error {
+	url := "https://api-drive.mypikpak.net/drive/v1/tasks"
+	params := map[string]string{
+		"task_ids":     strings.Join(taskIDs, ","),
+		"delete_files": strconv.FormatBool(deleteFiles),
+	}
+	_, err := d.request(url, http.MethodDelete, func(req *resty.Request) {
+		req.SetContext(ctx).
+ SetQueryParams(params) + }, nil) + if err != nil { + return fmt.Errorf("failed to delete tasks %v: %w", taskIDs, err) + } + return nil +} + +var _ driver.Driver = (*PikPak)(nil) diff --git a/drivers/pikpak/meta.go b/drivers/pikpak/meta.go new file mode 100644 index 0000000000000000000000000000000000000000..5abbc8796ca0f649883a9d6c4bbc21a05fb9a3cc --- /dev/null +++ b/drivers/pikpak/meta.go @@ -0,0 +1,29 @@ +package pikpak + +import ( + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/op" +) + +type Addition struct { + driver.RootID + Username string `json:"username" required:"true"` + Password string `json:"password" required:"true"` + Platform string `json:"platform" required:"true" default:"web" type:"select" options:"android,web,pc"` + RefreshToken string `json:"refresh_token" required:"true" default:""` + CaptchaToken string `json:"captcha_token" default:""` + DeviceID string `json:"device_id" required:"false" default:""` + DisableMediaLink bool `json:"disable_media_link" default:"true"` +} + +var config = driver.Config{ + Name: "PikPak", + LocalSort: true, + DefaultRoot: "", +} + +func init() { + op.RegisterDriver(func() driver.Driver { + return &PikPak{} + }) +} diff --git a/drivers/pikpak/types.go b/drivers/pikpak/types.go new file mode 100644 index 0000000000000000000000000000000000000000..2a959ebf05d4f6bc6209d0560c0ce60a5cbaad34 --- /dev/null +++ b/drivers/pikpak/types.go @@ -0,0 +1,197 @@ +package pikpak + +import ( + "fmt" + "strconv" + "time" + + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/pkg/utils" + hash_extend "github.com/alist-org/alist/v3/pkg/utils/hash" +) + +type Files struct { + Files []File `json:"files"` + NextPageToken string `json:"next_page_token"` +} + +type File struct { + Id string `json:"id"` + Kind string `json:"kind"` + Name string `json:"name"` + CreatedTime time.Time `json:"created_time"` + ModifiedTime time.Time `json:"modified_time"` + Hash string `json:"hash"` + Size string `json:"size"` + ThumbnailLink string `json:"thumbnail_link"` + WebContentLink string `json:"web_content_link"` + Medias []Media `json:"medias"` +} + +func fileToObj(f File) *model.ObjThumb { + size, _ := strconv.ParseInt(f.Size, 10, 64) + return &model.ObjThumb{ + Object: model.Object{ + ID: f.Id, + Name: f.Name, + Size: size, + Ctime: f.CreatedTime, + Modified: f.ModifiedTime, + IsFolder: f.Kind == "drive#folder", + HashInfo: utils.NewHashInfo(hash_extend.GCID, f.Hash), + }, + Thumbnail: model.Thumbnail{ + Thumbnail: f.ThumbnailLink, + }, + } +} + +type Media struct { + MediaId string `json:"media_id"` + MediaName string `json:"media_name"` + Video struct { + Height int `json:"height"` + Width int `json:"width"` + Duration int `json:"duration"` + BitRate int `json:"bit_rate"` + FrameRate int `json:"frame_rate"` + VideoCodec string `json:"video_codec"` + AudioCodec string `json:"audio_codec"` + VideoType string `json:"video_type"` + } `json:"video"` + Link struct { + Url string `json:"url"` + Token string `json:"token"` + Expire time.Time `json:"expire"` + } `json:"link"` + NeedMoreQuota bool `json:"need_more_quota"` + VipTypes []interface{} `json:"vip_types"` + RedirectLink string `json:"redirect_link"` + IconLink string `json:"icon_link"` + IsDefault bool `json:"is_default"` + Priority int `json:"priority"` + IsOrigin bool `json:"is_origin"` + ResolutionName string `json:"resolution_name"` + IsVisible bool `json:"is_visible"` + Category string `json:"category"` +} + +type UploadTaskData struct { + 
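// UploadTaskData is the response of the file-create call: when Resumable is
+	// nil the upload completed instantly (the server already had the content);
+	// otherwise Params carries the temporary OSS credentials for a real upload.
+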
UploadType string `json:"upload_type"`
+	//UPLOAD_TYPE_RESUMABLE
+	Resumable *struct {
+		Kind     string   `json:"kind"`
+		Params   S3Params `json:"params"`
+		Provider string   `json:"provider"`
+	} `json:"resumable"`
+
+	File File `json:"file"`
+}
+
+type S3Params struct {
+	AccessKeyID     string    `json:"access_key_id"`
+	AccessKeySecret string    `json:"access_key_secret"`
+	Bucket          string    `json:"bucket"`
+	Endpoint        string    `json:"endpoint"`
+	Expiration      time.Time `json:"expiration"`
+	Key             string    `json:"key"`
+	SecurityToken   string    `json:"security_token"`
+}
+
+// OfflineDownloadResp is the response of creating an offline download task
+type OfflineDownloadResp struct {
+	File       *string     `json:"file"`
+	Task       OfflineTask `json:"task"`
+	UploadType string      `json:"upload_type"`
+	URL        struct {
+		Kind string `json:"kind"`
+	} `json:"url"`
+}
+
+// OfflineListResp is the response of listing offline download tasks
+type OfflineListResp struct {
+	ExpiresIn     int64         `json:"expires_in"`
+	NextPageToken string        `json:"next_page_token"`
+	Tasks         []OfflineTask `json:"tasks"`
+}
+
+// OfflineTask describes a single offline download task
+type OfflineTask struct {
+	Callback          string            `json:"callback"`
+	CreatedTime       string            `json:"created_time"`
+	FileID            string            `json:"file_id"`
+	FileName          string            `json:"file_name"`
+	FileSize          string            `json:"file_size"`
+	IconLink          string            `json:"icon_link"`
+	ID                string            `json:"id"`
+	Kind              string            `json:"kind"`
+	Message           string            `json:"message"`
+	Name              string            `json:"name"`
+	Params            Params            `json:"params"`
+	Phase             string            `json:"phase"` // PHASE_TYPE_RUNNING, PHASE_TYPE_ERROR, PHASE_TYPE_COMPLETE, PHASE_TYPE_PENDING
+	Progress          int64             `json:"progress"`
+	ReferenceResource ReferenceResource `json:"reference_resource"`
+	Space             string            `json:"space"`
+	StatusSize        int64             `json:"status_size"`
+	Statuses          []string          `json:"statuses"`
+	ThirdTaskID       string            `json:"third_task_id"`
+	Type              string            `json:"type"`
+	UpdatedTime       string            `json:"updated_time"`
+	UserID            string            `json:"user_id"`
+}
+
+type Params struct {
+	Age         string  `json:"age"`
+	MIMEType    *string `json:"mime_type,omitempty"`
+	PredictType string  `json:"predict_type"`
+	URL         string  `json:"url"`
+}
+
+type ReferenceResource struct {
+	Type          string                 `json:"@type"`
+	Audit         interface{}            `json:"audit"`
+	Hash          string                 `json:"hash"`
+	IconLink      string                 `json:"icon_link"`
+	ID            string                 `json:"id"`
+	Kind          string                 `json:"kind"`
+	Medias        []Media                `json:"medias"`
+	MIMEType      string                 `json:"mime_type"`
+	Name          string                 `json:"name"`
+	Params        map[string]interface{} `json:"params"`
+	ParentID      string                 `json:"parent_id"`
+	Phase         string                 `json:"phase"`
+	Size          string                 `json:"size"`
+	Space         string                 `json:"space"`
+	Starred       bool                   `json:"starred"`
+	Tags          []string               `json:"tags"`
+	ThumbnailLink string                 `json:"thumbnail_link"`
+}
+
+type ErrResp struct {
+	ErrorCode        int64  `json:"error_code"`
+	ErrorMsg         string `json:"error"`
+	ErrorDescription string `json:"error_description"`
+}
+
+func (e *ErrResp) IsError() bool {
+	return e.ErrorCode != 0 || e.ErrorMsg != "" || e.ErrorDescription != ""
+}
+
+func (e *ErrResp) Error() string {
+	return fmt.Sprintf("ErrorCode: %d, Error: %s, ErrorDescription: %s", e.ErrorCode, e.ErrorMsg, e.ErrorDescription)
+}
+
+type CaptchaTokenRequest struct {
+	Action       string            `json:"action"`
+	CaptchaToken string            `json:"captcha_token"`
+	ClientID     string            `json:"client_id"`
+	DeviceID     string            `json:"device_id"`
+	Meta         map[string]string `json:"meta"`
+	RedirectUri  string            `json:"redirect_uri"`
+}
+
+type CaptchaTokenResponse struct {
+	CaptchaToken string `json:"captcha_token"`
+	ExpiresIn    int64  `json:"expires_in"`
+	Url          string `json:"url"`
+}
diff --git a/drivers/pikpak/util.go b/drivers/pikpak/util.go
new file mode 100644
index 0000000000000000000000000000000000000000..a104861b4ee799e74d385b8c3012b1fdbc6ae597
--- /dev/null
+++ b/drivers/pikpak/util.go
@@ -0,0 +1,692 @@
+package pikpak
+
+import (
+	"bytes"
+	"crypto/md5"
+	"crypto/sha1"
+	"encoding/hex"
+	"fmt"
+	"io"
+	"net/http"
+	"path/filepath"
+	"regexp"
+	"strings"
+	"sync"
+	"time"
+
+	"github.com/alist-org/alist/v3/drivers/base"
+	"github.com/alist-org/alist/v3/internal/driver"
+	"github.com/alist-org/alist/v3/internal/model"
+	"github.com/alist-org/alist/v3/internal/op"
+	"github.com/alist-org/alist/v3/pkg/utils"
+	"github.com/aliyun/aliyun-oss-go-sdk/oss"
+	"github.com/go-resty/resty/v2"
+	jsoniter "github.com/json-iterator/go"
+	"github.com/pkg/errors"
+)
+
+var AndroidAlgorithms = []string{
+	"SOP04dGzk0TNO7t7t9ekDbAmx+eq0OI1ovEx",
+	"nVBjhYiND4hZ2NCGyV5beamIr7k6ifAsAbl",
+	"Ddjpt5B/Cit6EDq2a6cXgxY9lkEIOw4yC1GDF28KrA",
+	"VVCogcmSNIVvgV6U+AochorydiSymi68YVNGiz",
+	"u5ujk5sM62gpJOsB/1Gu/zsfgfZO",
+	"dXYIiBOAHZgzSruaQ2Nhrqc2im",
+	"z5jUTBSIpBN9g4qSJGlidNAutX6",
+	"KJE2oveZ34du/g1tiimm",
+}
+
+var WebAlgorithms = []string{
+	"C9qPpZLN8ucRTaTiUMWYS9cQvWOE",
+	"+r6CQVxjzJV6LCV",
+	"F",
+	"pFJRC",
+	"9WXYIDGrwTCz2OiVlgZa90qpECPD6olt",
+	"/750aCr4lm/Sly/c",
+	"RB+DT/gZCrbV",
+	"",
+	"CyLsf7hdkIRxRm215hl",
+	"7xHvLi2tOYP0Y92b",
+	"ZGTXXxu8E/MIWaEDB+Sm/",
+	"1UI3",
+	"E7fP5Pfijd+7K+t6Tg/NhuLq0eEUVChpJSkrKxpO",
+	"ihtqpG6FMt65+Xk+tWUH2",
+	"NhXXU9rg4XXdzo7u5o",
+}
+
+var PCAlgorithms = []string{
+	"KHBJ07an7ROXDoK7Db",
+	"G6n399rSWkl7WcQmw5rpQInurc1DkLmLJqE",
+	"JZD1A3M4x+jBFN62hkr7VDhkkZxb9g3rWqRZqFAAb",
+	"fQnw/AmSlbbI91Ik15gpddGgyU7U",
+	"/Dv9JdPYSj3sHiWjouR95NTQff",
+	"yGx2zuTjbWENZqecNI+edrQgqmZKP",
+	"ljrbSzdHLwbqcRn",
+	"lSHAsqCkGDGxQqqwrVu",
+	"TsWXI81fD1",
+	"vk7hBjawK/rOSrSWajtbMk95nfgf3",
+}
+
+const (
+	OSSUserAgent               = "aliyun-sdk-android/2.9.13(Linux/Android 14/M2004j7ac;UKQ1.231108.001)"
+	OssSecurityTokenHeaderName = "X-OSS-Security-Token"
+	ThreadsNum                 = 10
+)
+
+const (
+	AndroidClientID      = "YNxT9w7GMdWvEOKa"
+	AndroidClientSecret  = "dbw2OtmVEeuUvIptb1Coyg"
+	AndroidClientVersion = "1.53.2"
+	AndroidPackageName   = "com.pikcloud.pikpak"
+	AndroidSdkVersion    = "2.0.6.206003"
+	WebClientID          = "YUMx5nI8ZU8Ap8pm"
+	WebClientSecret      = "dbw2OtmVEeuUvIptb1Coyg"
+	WebClientVersion     = "2.0.0"
+	WebPackageName       = "mypikpak.com"
+	WebSdkVersion        = "8.0.3"
+	PCClientID           = "YvtoWO6GNHiuCl7x"
+	PCClientSecret       = "1NIH5R1IEe2pAxZE3hv3uA"
+	PCClientVersion      = "undefined" // 2.6.11.4955
+	PCPackageName        = "mypikpak.com"
+	PCSdkVersion         = "8.0.3"
+)
+
+func (d *PikPak) login() error {
+	// Make sure username and password are provided
+	if d.Addition.Username == "" || d.Addition.Password == "" {
+		return errors.New("username or password is empty")
+	}
+
+	url := "https://user.mypikpak.net/v1/auth/signin"
+	// Prefer the user-supplied CaptchaToken (an already-verified captcha_token); otherwise request a fresh one
+	if d.GetCaptchaToken() == "" {
+		if err := d.RefreshCaptchaTokenInLogin(GetAction(http.MethodPost, url), d.Username); err != nil {
+			return err
+		}
+	}
+
+	var e ErrResp
+	res, err := base.RestyClient.SetRetryCount(1).R().SetError(&e).SetBody(base.Json{
+		"captcha_token": d.GetCaptchaToken(),
+		"client_id":     d.ClientID,
+		"client_secret": d.ClientSecret,
+		"username":      d.Username,
+		"password":      d.Password,
+	}).SetQueryParam("client_id", d.ClientID).Post(url)
+	if err != nil {
+		return err
+	}
+	if e.ErrorCode != 0 {
+		return &e
+	}
+	data := res.Body()
+	d.RefreshToken = jsoniter.Get(data, "refresh_token").ToString()
+	d.AccessToken = jsoniter.Get(data, "access_token").ToString()
+	d.Common.SetUserID(jsoniter.Get(data, "sub").ToString())
+	return nil
+}
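+
+// Token flow, for reference: signin returns refresh_token, access_token and
+// the user id ("sub"); refreshToken below trades the long-lived refresh_token
+// for a new access_token whenever the API reports error codes 4122/4121/16.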
+func (d *PikPak) refreshToken(refreshToken string) error {
+	url := "https://user.mypikpak.net/v1/auth/token"
+	var e ErrResp
+	res, err := base.RestyClient.SetRetryCount(1).R().SetError(&e).
+		SetHeader("user-agent", "").SetBody(base.Json{
+		"client_id":     d.ClientID,
+		"client_secret": d.ClientSecret,
+		"grant_type":    "refresh_token",
+		"refresh_token": refreshToken,
+	}).SetQueryParam("client_id", d.ClientID).Post(url)
+	if err != nil {
+		d.Status = err.Error()
+		op.MustSaveDriverStorage(d)
+		return err
+	}
+	if e.ErrorCode != 0 {
+		if e.ErrorCode == 4126 {
+			// The stored refresh_token is invalid; without username/password we cannot recover
+			if d.Addition.Username == "" || d.Addition.Password == "" {
+				return errors.New("refresh_token invalid, please re-provide refresh_token")
+			} else {
+				// refresh_token invalid, re-login
+				return d.login()
+			}
+		}
+		d.Status = e.Error()
+		op.MustSaveDriverStorage(d)
+		return errors.New(e.Error())
+	}
+	data := res.Body()
+	d.Status = "work"
+	d.RefreshToken = jsoniter.Get(data, "refresh_token").ToString()
+	d.AccessToken = jsoniter.Get(data, "access_token").ToString()
+	d.Common.SetUserID(jsoniter.Get(data, "sub").ToString())
+	d.Addition.RefreshToken = d.RefreshToken
+	op.MustSaveDriverStorage(d)
+	return nil
+}
+
+func (d *PikPak) request(url string, method string, callback base.ReqCallback, resp interface{}) ([]byte, error) {
+	req := base.RestyClient.R()
+	req.SetHeaders(map[string]string{
+		//"Authorization":   "Bearer " + d.AccessToken,
+		"User-Agent":      d.GetUserAgent(),
+		"X-Device-ID":     d.GetDeviceID(),
+		"X-Captcha-Token": d.GetCaptchaToken(),
+	})
+	if d.AccessToken != "" {
+		req.SetHeader("Authorization", "Bearer "+d.AccessToken)
+	}
+
+	if callback != nil {
+		callback(req)
+	}
+	if resp != nil {
+		req.SetResult(resp)
+	}
+	var e ErrResp
+	req.SetError(&e)
+	res, err := req.Execute(method, url)
+	if err != nil {
+		return nil, err
+	}
+
+	switch e.ErrorCode {
+	case 0:
+		return res.Body(), nil
+	case 4122, 4121, 16:
+		// access_token expired: refresh it and retry
+		if err1 := d.refreshToken(d.RefreshToken); err1 != nil {
+			return nil, err1
+		}
+		return d.request(url, method, callback, resp)
+	case 9: // captcha token expired: refresh it and retry
+		if err = d.RefreshCaptchaTokenAtLogin(GetAction(method, url), d.GetUserID()); err != nil {
+			return nil, err
+		}
+		return d.request(url, method, callback, resp)
+	case 10: // rate limited (too many requests)
+		return nil, errors.New(e.ErrorDescription)
+	default:
+		return nil, errors.New(e.Error())
+	}
+}
+
+func (d *PikPak) requestWithCaptchaToken(url string, method string, callback base.ReqCallback, resp interface{}) ([]byte, error) {
+
+	if err := d.RefreshCaptchaTokenAtLogin(GetAction(method, url), d.Common.UserID); err != nil {
+		return nil, err
+	}
+
+	data, err := d.request(url, method, func(req *resty.Request) {
+		req.SetHeaders(map[string]string{
+			"User-Agent":      d.GetUserAgent(),
+			"X-Device-ID":     d.GetDeviceID(),
+			"X-Captcha-Token": d.GetCaptchaToken(),
+		})
+		if callback != nil {
+			callback(req)
+		}
+	}, resp)
+
+	// Success: return the response body directly (otherwise data would be
+	// silently dropped on the nil-error path below)
+	if err == nil {
+		return data, nil
+	}
+
+	errResp, ok := err.(*ErrResp)
+
+	if !ok {
+		return nil, err
+	}
+
+	switch errResp.ErrorCode {
+	case 0:
+		return data, nil
+	//case 4122, 4121, 10, 16:
+	//	if d.refreshTokenFunc != nil {
+	//		if err = xc.refreshTokenFunc(); err == nil {
+	//			break
+	//		}
+	//	}
+	//	return nil, err
+	case 9: // captcha token expired: refresh it and retry
+		if err = d.RefreshCaptchaTokenAtLogin(GetAction(method, url), d.Common.UserID); err != nil {
+			return nil, err
+		}
+	default:
+		return nil, err
+	}
+	return d.request(url, method, callback, resp)
+}
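+
+// getFiles pages through /drive/v1/files with page_token until the server
+// returns an empty token; "first" is only a sentinel to enter the loop.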
+func (d *PikPak) getFiles(id string) ([]File, error) {
+	res := make([]File, 0)
+	pageToken := "first"
+	for pageToken != "" {
+		if pageToken == "first" {
+			pageToken = ""
+		}
+		query := map[string]string{
+			"parent_id":      id,
+			"thumbnail_size": "SIZE_LARGE",
+			"with_audit":     "true",
+			"limit":          "100",
+			"filters":        `{"phase":{"eq":"PHASE_TYPE_COMPLETE"},"trashed":{"eq":false}}`,
+			"page_token":     pageToken,
+		}
+		var resp Files
+		_, err := d.request("https://api-drive.mypikpak.net/drive/v1/files", http.MethodGet, func(req *resty.Request) {
+			req.SetQueryParams(query)
+		}, &resp)
+		if err != nil {
+			return nil, err
+		}
+		pageToken = resp.NextPageToken
+		res = append(res, resp.Files...)
+	}
+	return res, nil
+}
+
+func GetAction(method string, url string) string {
+	urlpath := regexp.MustCompile(`://[^/]+((/[^/\s?#]+)*)`).FindStringSubmatch(url)[1]
+	return method + ":" + urlpath
+}
+
+type Common struct {
+	client       *resty.Client
+	CaptchaToken string
+	UserID       string
+	// Required values used for request signing
+	ClientID      string
+	ClientSecret  string
+	ClientVersion string
+	PackageName   string
+	Algorithms    []string
+	DeviceID      string
+	UserAgent     string
+	// Callback invoked after the captcha token has been refreshed successfully
+	RefreshCTokenCk func(token string)
+}
+
+func generateDeviceSign(deviceID, packageName string) string {
+
+	signatureBase := fmt.Sprintf("%s%s%s%s", deviceID, packageName, "1", "appkey")
+
+	sha1Hash := sha1.New()
+	sha1Hash.Write([]byte(signatureBase))
+	sha1Result := sha1Hash.Sum(nil)
+
+	sha1String := hex.EncodeToString(sha1Result)
+
+	md5Hash := md5.New()
+	md5Hash.Write([]byte(sha1String))
+	md5Result := md5Hash.Sum(nil)
+
+	md5String := hex.EncodeToString(md5Result)
+
+	deviceSign := fmt.Sprintf("div101.%s%s", deviceID, md5String)
+
+	return deviceSign
+}
+
+func BuildCustomUserAgent(deviceID, clientID, appName, sdkVersion, clientVersion, packageName, userID string) string {
+	deviceSign := generateDeviceSign(deviceID, packageName)
+	var sb strings.Builder
+
+	sb.WriteString(fmt.Sprintf("ANDROID-%s/%s ", appName, clientVersion))
+	sb.WriteString("protocolVersion/200 ")
+	sb.WriteString("accesstype/ ")
+	sb.WriteString(fmt.Sprintf("clientid/%s ", clientID))
+	sb.WriteString(fmt.Sprintf("clientversion/%s ", clientVersion))
+	sb.WriteString("action_type/ ")
+	sb.WriteString("networktype/WIFI ")
+	sb.WriteString("sessionid/ ")
+	sb.WriteString(fmt.Sprintf("deviceid/%s ", deviceID))
+	sb.WriteString("providername/NONE ")
+	sb.WriteString(fmt.Sprintf("devicesign/%s ", deviceSign))
+	sb.WriteString("refresh_token/ ")
+	sb.WriteString(fmt.Sprintf("sdkversion/%s ", sdkVersion))
+	sb.WriteString(fmt.Sprintf("datetime/%d ", time.Now().UnixMilli()))
+	sb.WriteString(fmt.Sprintf("usrno/%s ", userID))
+	sb.WriteString(fmt.Sprintf("appname/android-%s ", appName))
+	sb.WriteString("session_origin/ ")
+	sb.WriteString("grant_type/ ")
+	sb.WriteString("appid/ ")
+	sb.WriteString("clientip/ ")
+	sb.WriteString("devicename/Xiaomi_M2004j7ac ")
+	sb.WriteString("osversion/13 ")
+	sb.WriteString("platformversion/10 ")
+	sb.WriteString("accessmode/ ")
+	sb.WriteString("devicemodel/M2004J7AC ")
+
+	return sb.String()
+}
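+
+// For illustration, the resulting UA looks roughly like (values abbreviated):
+//   ANDROID-com.pikcloud.pikpak/1.53.2 protocolVersion/200 ... deviceid/<md5>
+//   devicesign/div101.<deviceid><md5(sha1(deviceid+package+"1"+"appkey"))> ... devicemodel/M2004J7AC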
+
+func (c *Common) SetDeviceID(deviceID string) {
+	c.DeviceID = deviceID
+}
+
+func (c *Common) SetUserID(userID string) {
+	c.UserID = userID
+}
+
+func (c *Common) SetUserAgent(userAgent string) {
+	c.UserAgent = userAgent
+}
+
+func (c *Common) SetCaptchaToken(captchaToken string) {
+	c.CaptchaToken = captchaToken
+}
+func (c *Common) GetCaptchaToken() string {
+	return c.CaptchaToken
+}
+
+func (c *Common) GetUserAgent() string {
+	return c.UserAgent
+}
+
+func (c *Common) GetDeviceID() string {
+	return c.DeviceID
+}
+
+func (c *Common) GetUserID() string {
+	return c.UserID
+}
+
+// RefreshCaptchaTokenAtLogin refreshes the captcha token (after login)
+func (d *PikPak) RefreshCaptchaTokenAtLogin(action, userID string) error {
+	metas := map[string]string{
+		"client_version": d.ClientVersion,
+		"package_name":   d.PackageName,
+		"user_id":        userID,
+	}
+	metas["timestamp"], metas["captcha_sign"] = d.Common.GetCaptchaSign()
+	return d.refreshCaptchaToken(action, metas)
+}
+
+// RefreshCaptchaTokenInLogin refreshes the captcha token (during login)
+func (d *PikPak) RefreshCaptchaTokenInLogin(action, username string) error {
+	metas := make(map[string]string)
+	if ok, _ := regexp.MatchString(`\w+([-+.]\w+)*@\w+([-.]\w+)*\.\w+([-.]\w+)*`, username); ok {
+		metas["email"] = username
+	} else if len(username) >= 11 && len(username) <= 18 {
+		metas["phone_number"] = username
+	} else {
+		metas["username"] = username
+	}
+	return d.refreshCaptchaToken(action, metas)
+}
+
+// GetCaptchaSign computes the captcha signature by folding the concatenated
+// client info through each algorithm string with MD5
+func (c *Common) GetCaptchaSign() (timestamp, sign string) {
+	timestamp = fmt.Sprint(time.Now().UnixMilli())
+	str := fmt.Sprint(c.ClientID, c.ClientVersion, c.PackageName, c.DeviceID, timestamp)
+	for _, algorithm := range c.Algorithms {
+		str = utils.GetMD5EncodeStr(str + algorithm)
+	}
+	sign = "1." + str
+	return
+}
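+
+// Worked example for GetCaptchaSign (hypothetical algorithm values): with
+// Algorithms = ["a", "b"] and base = clientID+clientVersion+packageName+deviceID+timestamp,
+// the resulting sign is "1." + md5(md5(base+"a") + "b").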
+
+// refreshCaptchaToken requests a new CaptchaToken from the shield/captcha/init endpoint
+func (d *PikPak) refreshCaptchaToken(action string, metas map[string]string) error {
+	param := CaptchaTokenRequest{
+		Action:       action,
+		CaptchaToken: d.GetCaptchaToken(),
+		ClientID:     d.ClientID,
+		DeviceID:     d.GetDeviceID(),
+		Meta:         metas,
+		RedirectUri:  "xlaccsdk01://xbase.cloud/callback?state=harbor",
+	}
+	var e ErrResp
+	var resp CaptchaTokenResponse
+	_, err := d.request("https://user.mypikpak.net/v1/shield/captcha/init", http.MethodPost, func(req *resty.Request) {
+		req.SetError(&e).SetBody(param).SetQueryParam("client_id", d.ClientID)
+	}, &resp)
+
+	if err != nil {
+		return err
+	}
+
+	if e.IsError() {
+		return errors.New(e.Error())
+	}
+
+	if resp.Url != "" {
+		return fmt.Errorf(`need verify: <a target="_blank" href="%s">Click Here</a>`, resp.Url)
+	}
+
+	if d.Common.RefreshCTokenCk != nil {
+		d.Common.RefreshCTokenCk(resp.CaptchaToken)
+	}
+	d.Common.SetCaptchaToken(resp.CaptchaToken)
+	return nil
+}
+
+func (d *PikPak) UploadByOSS(params *S3Params, stream model.FileStreamer, up driver.UpdateProgress) error {
+	ossClient, err := oss.New(params.Endpoint, params.AccessKeyID, params.AccessKeySecret)
+	if err != nil {
+		return err
+	}
+	bucket, err := ossClient.Bucket(params.Bucket)
+	if err != nil {
+		return err
+	}
+
+	err = bucket.PutObject(params.Key, stream, OssOption(params)...)
+	if err != nil {
+		return err
+	}
+	return nil
+}
+func (d *PikPak) UploadByMultipart(params *S3Params, fileSize int64, stream model.FileStreamer, up driver.UpdateProgress) error {
+	var (
+		chunks    []oss.FileChunk
+		parts     []oss.UploadPart
+		imur      oss.InitiateMultipartUploadResult
+		ossClient *oss.Client
+		bucket    *oss.Bucket
+		err       error
+	)
+
+	tmpF, err := stream.CacheFullInTempFile()
+	if err != nil {
+		return err
+	}
+
+	if ossClient, err = oss.New(params.Endpoint, params.AccessKeyID, params.AccessKeySecret); err != nil {
+		return err
+	}
+
+	if bucket, err = ossClient.Bucket(params.Bucket); err != nil {
+		return err
+	}
+
+	ticker := time.NewTicker(time.Hour * 12)
+	defer ticker.Stop()
+	// Overall timeout for the whole upload
+	timeout := time.NewTimer(time.Hour * 24)
+
+	if chunks, err = SplitFile(fileSize); err != nil {
+		return err
+	}
+
+	if imur, err = bucket.InitiateMultipartUpload(params.Key,
+		oss.SetHeader(OssSecurityTokenHeaderName, params.SecurityToken),
+		oss.UserAgentHeader(OSSUserAgent),
+	); err != nil {
+		return err
+	}
+
+	wg := sync.WaitGroup{}
+	wg.Add(len(chunks))
+
+	chunksCh := make(chan oss.FileChunk)
+	errCh := make(chan error)
+	UploadedPartsCh := make(chan oss.UploadPart)
+	quit := make(chan struct{})
+
+	// producer
+	go chunksProducer(chunksCh, chunks)
+	go func() {
+		wg.Wait()
+		quit <- struct{}{}
+	}()
+
+	// consumers
+	for i := 0; i < ThreadsNum; i++ {
+		go func(threadId int) {
+			defer func() {
+				if r := recover(); r != nil {
+					errCh <- fmt.Errorf("recovered in %v", r)
+				}
+			}()
+			for chunk := range chunksCh {
+				var part oss.UploadPart // retry on failure, up to 3 attempts in total
+				for retry := 0; retry < 3; retry++ {
+					select {
+					case <-ticker.C:
+						errCh <- errors.Wrap(err, "ossToken expired")
+					default:
+					}
+
+					buf := make([]byte, chunk.Size)
+					if _, err = tmpF.ReadAt(buf, chunk.Offset); err != nil && !errors.Is(err, io.EOF) {
+						continue
+					}
+
+					b := bytes.NewBuffer(buf)
+					if part, err = bucket.UploadPart(imur, b, chunk.Size, chunk.Number, OssOption(params)...); err == nil {
+						break
+					}
+				}
+				if err != nil {
+					errCh <- errors.Wrap(err, fmt.Sprintf("error while uploading part %d of %s: %v", chunk.Number, stream.GetName(), err))
+				}
+				UploadedPartsCh <- part
+			}
+		}(i)
+	}
+
+	go func() {
+		for part := range UploadedPartsCh {
+			parts = append(parts, part)
+			wg.Done()
+		}
+	}()
+LOOP:
+	for {
+		select {
+		case <-ticker.C:
+			// ossToken expired
+			return err
+		case <-quit:
+			break LOOP
+		case <-errCh:
+			return err
+		case <-timeout.C:
+			return fmt.Errorf("upload timed out")
+		}
+	}
+
+	// An EOF error here comes from the XML Unmarshal step; the response is actually JSON, so the upload did succeed
+	if _, err = bucket.CompleteMultipartUpload(imur, parts, OssOption(params)...); err != nil && !errors.Is(err, io.EOF) {
+		// When the file name contains '&' or '<', XML parsing of the response fails even though the upload succeeded
+		if filename := filepath.Base(stream.GetName()); !strings.ContainsAny(filename, "&<") {
+			return err
+		}
+	}
+	return nil
+}
+
+func chunksProducer(ch chan oss.FileChunk, chunks []oss.FileChunk) {
+	for _, chunk := range chunks {
+		ch <- chunk
+	}
+}
+
+func SplitFile(fileSize int64) (chunks []oss.FileChunk, err error) {
+	for i := int64(1); i < 10; i++ {
+		if fileSize < i*utils.GB { // files under i GB are split into i*100 parts
+			if chunks, err = SplitFileByPartNum(fileSize, int(i*100)); err != nil {
+				return
+			}
+			break
+		}
+	}
+	if fileSize > 9*utils.GB { // files over 9 GB are split into 1000 parts
+		if chunks, err = SplitFileByPartNum(fileSize, 1000); err != nil {
+			return
+		}
+	}
+	// A single part may not be smaller than 1MB
+	if chunks[0].Size < 1*utils.MB {
+		if chunks, err = SplitFileByPartSize(fileSize, 1*utils.MB); err != nil {
+			return
+		}
+	}
+	return
+}
+
+// SplitFileByPartNum splits big file into parts by the
num of parts. +// Split the file with specified parts count, returns the split result when error is nil. +func SplitFileByPartNum(fileSize int64, chunkNum int) ([]oss.FileChunk, error) { + if chunkNum <= 0 || chunkNum > 10000 { + return nil, errors.New("chunkNum invalid") + } + + if int64(chunkNum) > fileSize { + return nil, errors.New("oss: chunkNum invalid") + } + + var chunks []oss.FileChunk + chunk := oss.FileChunk{} + chunkN := (int64)(chunkNum) + for i := int64(0); i < chunkN; i++ { + chunk.Number = int(i + 1) + chunk.Offset = i * (fileSize / chunkN) + if i == chunkN-1 { + chunk.Size = fileSize/chunkN + fileSize%chunkN + } else { + chunk.Size = fileSize / chunkN + } + chunks = append(chunks, chunk) + } + + return chunks, nil +} + +// SplitFileByPartSize splits big file into parts by the size of parts. +// Splits the file by the part size. Returns the FileChunk when error is nil. +func SplitFileByPartSize(fileSize int64, chunkSize int64) ([]oss.FileChunk, error) { + if chunkSize <= 0 { + return nil, errors.New("chunkSize invalid") + } + + chunkN := fileSize / chunkSize + if chunkN >= 10000 { + return nil, errors.New("Too many parts, please increase part size") + } + + var chunks []oss.FileChunk + chunk := oss.FileChunk{} + for i := int64(0); i < chunkN; i++ { + chunk.Number = int(i + 1) + chunk.Offset = i * chunkSize + chunk.Size = chunkSize + chunks = append(chunks, chunk) + } + + if fileSize%chunkSize > 0 { + chunk.Number = len(chunks) + 1 + chunk.Offset = int64(len(chunks)) * chunkSize + chunk.Size = fileSize % chunkSize + chunks = append(chunks, chunk) + } + + return chunks, nil +} + +// OssOption get options +func OssOption(params *S3Params) []oss.Option { + options := []oss.Option{ + oss.SetHeader(OssSecurityTokenHeaderName, params.SecurityToken), + oss.UserAgentHeader(OSSUserAgent), + } + return options +} diff --git a/drivers/pikpak_proxy/driver.go b/drivers/pikpak_proxy/driver.go new file mode 100644 index 0000000000000000000000000000000000000000..0e913888e0e9dd00dfcd0b1e44162bc090bcd538 --- /dev/null +++ b/drivers/pikpak_proxy/driver.go @@ -0,0 +1,375 @@ +package PikPakProxy + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "strconv" + "strings" + + "github.com/alist-org/alist/v3/drivers/base" + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/internal/op" + "github.com/alist-org/alist/v3/pkg/utils" + hash_extend "github.com/alist-org/alist/v3/pkg/utils/hash" + "github.com/go-resty/resty/v2" + log "github.com/sirupsen/logrus" +) + +type PikPakProxy struct { + model.Storage + Addition + *Common + RefreshToken string + AccessToken string +} + +func (d *PikPakProxy) Config() driver.Config { + return config +} + +func (d *PikPakProxy) GetAddition() driver.Additional { + return &d.Addition +} + +func (d *PikPakProxy) Init(ctx context.Context) (err error) { + + if d.Common == nil { + d.Common = &Common{ + client: base.NewRestyClient(), + CaptchaToken: "", + UserID: "", + DeviceID: utils.GetMD5EncodeStr(d.Username + d.Password), + UserAgent: "", + RefreshCTokenCk: func(token string) { + d.Common.CaptchaToken = token + op.MustSaveDriverStorage(d) + }, + UseProxy: d.Addition.UseProxy, + ProxyUrl: d.Addition.ProxyUrl, + } + } + + if d.Platform == "android" { + d.ClientID = AndroidClientID + d.ClientSecret = AndroidClientSecret + d.ClientVersion = AndroidClientVersion + d.PackageName = AndroidPackageName + d.Algorithms = AndroidAlgorithms + d.UserAgent = 
BuildCustomUserAgent(utils.GetMD5EncodeStr(d.Username+d.Password), AndroidClientID, AndroidPackageName, AndroidSdkVersion, AndroidClientVersion, AndroidPackageName, "")
+	} else if d.Platform == "web" {
+		d.ClientID = WebClientID
+		d.ClientSecret = WebClientSecret
+		d.ClientVersion = WebClientVersion
+		d.PackageName = WebPackageName
+		d.Algorithms = WebAlgorithms
+		d.UserAgent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/117.0.0.0 Safari/537.36"
+	} else if d.Platform == "pc" {
+		d.ClientID = PCClientID
+		d.ClientSecret = PCClientSecret
+		d.ClientVersion = PCClientVersion
+		d.PackageName = PCPackageName
+		d.Algorithms = PCAlgorithms
+		d.UserAgent = "MainWindow Mozilla/5.0 (Windows NT 10.0; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) PikPak/2.6.11.4955 Chrome/100.0.4896.160 Electron/18.3.15 Safari/537.36"
+	}
+
+	if d.Addition.CaptchaToken != "" && d.Addition.RefreshToken == "" {
+		d.SetCaptchaToken(d.Addition.CaptchaToken)
+	}
+
+	if d.Addition.DeviceID != "" {
+		d.SetDeviceID(d.Addition.DeviceID)
+	} else {
+		d.Addition.DeviceID = d.Common.DeviceID
+		op.MustSaveDriverStorage(d)
+	}
+	// If a RefreshToken is already configured, use it to obtain an AccessToken directly
+	if d.Addition.RefreshToken != "" {
+		if err = d.refreshToken(d.Addition.RefreshToken); err != nil {
+			return err
+		}
+	} else {
+		// No RefreshToken configured: try to log in and obtain one
+		if err = d.login(); err != nil {
+			return err
+		}
+	}
+
+	// Obtain a CaptchaToken
+	err = d.RefreshCaptchaTokenAtLogin(GetAction(http.MethodGet, "https://api-drive.mypikpak.net/drive/v1/files"), d.Common.GetUserID())
+	if err != nil {
+		return err
+	}
+
+	// Update the UserAgent (the Android UA embeds the user ID, which is only known after login)
+	if d.Platform == "android" {
+		d.Common.UserAgent = BuildCustomUserAgent(utils.GetMD5EncodeStr(d.Username+d.Password), AndroidClientID, AndroidPackageName, AndroidSdkVersion, AndroidClientVersion, AndroidPackageName, d.Common.UserID)
+	}
+
+	// Persist the now-valid RefreshToken
+	d.Addition.RefreshToken = d.RefreshToken
+	op.MustSaveDriverStorage(d)
+
+	return nil
+}
+
+func (d *PikPakProxy) Drop(ctx context.Context) error {
+	return nil
+}
+
+func (d *PikPakProxy) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) {
+	files, err := d.getFiles(dir.GetID())
+	if err != nil {
+		return nil, err
+	}
+	return utils.SliceConvert(files, func(src File) (model.Obj, error) {
+		return fileToObj(src), nil
+	})
+}
+
+func (d *PikPakProxy) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) {
+	var resp File
+	var url string
+	queryParams := map[string]string{
+		"_magic":         "2021",
+		"usage":          "FETCH",
+		"thumbnail_size": "SIZE_LARGE",
+	}
+	if !d.DisableMediaLink {
+		queryParams["usage"] = "CACHE"
+	}
+	_, err := d.request(fmt.Sprintf("https://api-drive.mypikpak.net/drive/v1/files/%s", file.GetID()),
+		http.MethodGet, func(req *resty.Request) {
+			req.SetQueryParams(queryParams)
+		}, &resp)
+	if err != nil {
+		return nil, err
+	}
+	url = resp.WebContentLink
+
+	if !d.DisableMediaLink && len(resp.Medias) > 0 && resp.Medias[0].Link.Url != "" {
+		log.Debugln("use media link")
+		url = resp.Medias[0].Link.Url
+	}
+
+	if d.Addition.UseProxy {
+		if strings.HasSuffix(d.Addition.ProxyUrl, "/") {
+			url = d.Addition.ProxyUrl + url
+		} else {
+			url = d.Addition.ProxyUrl + "/" + url
+		}
+
+	}
+
+	return &model.Link{
+		URL: url,
+	}, nil
+}
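+
+// Illustrative example (hypothetical hosts): with ProxyUrl = "https://proxy.example.com",
+// a link such as https://dl-a10b.mypikpak.net/file?sig=... is rewritten to
+// https://proxy.example.com/https://dl-a10b.mypikpak.net/file?sig=..., i.e. the
+// proxy receives the full upstream URL as its path and is expected to forward the request.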
req.SetBody(base.Json{ + "kind": "drive#folder", + "parent_id": parentDir.GetID(), + "name": dirName, + }) + }, nil) + return err +} + +func (d *PikPakProxy) Move(ctx context.Context, srcObj, dstDir model.Obj) error { + _, err := d.request("https://api-drive.mypikpak.net/drive/v1/files:batchMove", http.MethodPost, func(req *resty.Request) { + req.SetBody(base.Json{ + "ids": []string{srcObj.GetID()}, + "to": base.Json{ + "parent_id": dstDir.GetID(), + }, + }) + }, nil) + return err +} + +func (d *PikPakProxy) Rename(ctx context.Context, srcObj model.Obj, newName string) error { + _, err := d.request("https://api-drive.mypikpak.net/drive/v1/files/"+srcObj.GetID(), http.MethodPatch, func(req *resty.Request) { + req.SetBody(base.Json{ + "name": newName, + }) + }, nil) + return err +} + +func (d *PikPakProxy) Copy(ctx context.Context, srcObj, dstDir model.Obj) error { + _, err := d.request("https://api-drive.mypikpak.net/drive/v1/files:batchCopy", http.MethodPost, func(req *resty.Request) { + req.SetBody(base.Json{ + "ids": []string{srcObj.GetID()}, + "to": base.Json{ + "parent_id": dstDir.GetID(), + }, + }) + }, nil) + return err +} + +func (d *PikPakProxy) Remove(ctx context.Context, obj model.Obj) error { + // https://api-drive.mypikpak.com/drive/v1/files:batchTrash + _, err := d.request("https://api-drive.mypikpak.net/drive/v1/files:batchDelete", http.MethodPost, func(req *resty.Request) { + req.SetBody(base.Json{ + "ids": []string{obj.GetID()}, + }) + }, nil) + return err +} + +func (d *PikPakProxy) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error { + hi := stream.GetHash() + sha1Str := hi.GetHash(hash_extend.GCID) + if len(sha1Str) < hash_extend.GCID.Width { + tFile, err := stream.CacheFullInTempFile() + if err != nil { + return err + } + + sha1Str, err = utils.HashFile(hash_extend.GCID, tFile, stream.GetSize()) + if err != nil { + return err + } + } + + var resp UploadTaskData + res, err := d.request("https://api-drive.mypikpak.net/drive/v1/files", http.MethodPost, func(req *resty.Request) { + req.SetBody(base.Json{ + "kind": "drive#file", + "name": stream.GetName(), + "size": stream.GetSize(), + "hash": strings.ToUpper(sha1Str), + "upload_type": "UPLOAD_TYPE_RESUMABLE", + "objProvider": base.Json{"provider": "UPLOAD_TYPE_UNKNOWN"}, + "parent_id": dstDir.GetID(), + "folder_type": "NORMAL", + }) + }, &resp) + if err != nil { + return err + } + + // 秒传成功 + if resp.Resumable == nil { + log.Debugln(string(res)) + return nil + } + + params := resp.Resumable.Params + //endpoint := strings.Join(strings.Split(params.Endpoint, ".")[1:], ".") + // web 端上传 返回的endpoint 为 `mypikpak.net` | android 端上传 返回的endpoint 为 `vip-lixian-07.mypikpak.net`· + if d.Addition.Platform == "android" { + params.Endpoint = "mypikpak.net" + } + + if stream.GetSize() <= 10*utils.MB { // 文件大小 小于10MB,改用普通模式上传 + return d.UploadByOSS(¶ms, stream, up) + } + // 分片上传 + return d.UploadByMultipart(¶ms, stream.GetSize(), stream, up) +} + +// 离线下载文件 +func (d *PikPakProxy) Offline(ctx context.Context, args model.OtherArgs) (interface{}, error) { + requestBody := base.Json{ + "kind": "drive#file", + "name": "", + "upload_type": "UPLOAD_TYPE_URL", + "url": base.Json{ + "url": args.Data, + }, + "parent_id": args.Obj.GetID(), + "folder_type": "", + } + + _, err := d.requestWithCaptchaToken("https://api-drive.mypikpak.com/drive/v1/files", + http.MethodPost, func(r *resty.Request) { + r.SetContext(ctx) + r.SetBody(requestBody) + }, nil) + if err != nil { + return nil, err + } + return 
"ok", nil + + // var resp OfflineDownloadResp + // _, err := d.request("https://api-drive.mypikpak.net/drive/v1/files", http.MethodPost, func(req *resty.Request) { + // req.SetBody(requestBody) + // }, &resp) + + // if err != nil { + // return nil, err + // } + + //return &resp.Task, err +} + +/* +获取离线下载任务列表 +phase 可能的取值: +PHASE_TYPE_RUNNING, PHASE_TYPE_ERROR, PHASE_TYPE_COMPLETE, PHASE_TYPE_PENDING +*/ +func (d *PikPakProxy) OfflineList(ctx context.Context, nextPageToken string, phase []string) ([]OfflineTask, error) { + res := make([]OfflineTask, 0) + url := "https://api-drive.mypikpak.net/drive/v1/tasks" + + if len(phase) == 0 { + phase = []string{"PHASE_TYPE_RUNNING", "PHASE_TYPE_ERROR", "PHASE_TYPE_COMPLETE", "PHASE_TYPE_PENDING"} + } + params := map[string]string{ + "type": "offline", + "thumbnail_size": "SIZE_SMALL", + "limit": "10000", + "page_token": nextPageToken, + "with": "reference_resource", + } + + // 处理 phase 参数 + if len(phase) > 0 { + filters := base.Json{ + "phase": map[string]string{ + "in": strings.Join(phase, ","), + }, + } + filtersJSON, err := json.Marshal(filters) + if err != nil { + return nil, fmt.Errorf("failed to marshal filters: %w", err) + } + params["filters"] = string(filtersJSON) + } + + var resp OfflineListResp + _, err := d.request(url, http.MethodGet, func(req *resty.Request) { + req.SetContext(ctx). + SetQueryParams(params) + }, &resp) + + if err != nil { + return nil, fmt.Errorf("failed to get offline list: %w", err) + } + res = append(res, resp.Tasks...) + return res, nil +} + +func (d *PikPakProxy) DeleteOfflineTasks(ctx context.Context, taskIDs []string, deleteFiles bool) error { + url := "https://api-drive.mypikpak.net/drive/v1/tasks" + params := map[string]string{ + "task_ids": strings.Join(taskIDs, ","), + "delete_files": strconv.FormatBool(deleteFiles), + } + _, err := d.request(url, http.MethodDelete, func(req *resty.Request) { + req.SetContext(ctx). 
+ SetQueryParams(params) + }, nil) + if err != nil { + return fmt.Errorf("failed to delete tasks %v: %w", taskIDs, err) + } + return nil +} + +var _ driver.Driver = (*PikPakProxy)(nil) diff --git a/drivers/pikpak_proxy/meta.go b/drivers/pikpak_proxy/meta.go new file mode 100644 index 0000000000000000000000000000000000000000..978038aa22429c7240886d70fdc62b8630ea421a --- /dev/null +++ b/drivers/pikpak_proxy/meta.go @@ -0,0 +1,33 @@ +package PikPakProxy + +import ( + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/op" +) + +type Addition struct { + driver.RootID + Username string `json:"username" required:"true"` + Password string `json:"password" required:"true"` + Platform string `json:"platform" required:"true" default:"web" type:"select" options:"android,web,pc"` + RefreshToken string `json:"refresh_token" required:"true" default:""` + CaptchaToken string `json:"captcha_token" default:""` + DeviceID string `json:"device_id" required:"false" default:""` + DisableMediaLink bool `json:"disable_media_link" default:"true"` + //是否使用代理 + UseProxy bool `json:"use_proxy"` + //下代理地址 + ProxyUrl string `json:"proxy_url" default:""` +} + +var config = driver.Config{ + Name: "PikPakProxy", + LocalSort: true, + DefaultRoot: "", +} + +func init() { + op.RegisterDriver(func() driver.Driver { + return &PikPakProxy{} + }) +} diff --git a/drivers/pikpak_proxy/types.go b/drivers/pikpak_proxy/types.go new file mode 100644 index 0000000000000000000000000000000000000000..fb159944177917434552e240e120bb7bc126e09b --- /dev/null +++ b/drivers/pikpak_proxy/types.go @@ -0,0 +1,197 @@ +package PikPakProxy + +import ( + "fmt" + "strconv" + "time" + + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/pkg/utils" + hash_extend "github.com/alist-org/alist/v3/pkg/utils/hash" +) + +type Files struct { + Files []File `json:"files"` + NextPageToken string `json:"next_page_token"` +} + +type File struct { + Id string `json:"id"` + Kind string `json:"kind"` + Name string `json:"name"` + CreatedTime time.Time `json:"created_time"` + ModifiedTime time.Time `json:"modified_time"` + Hash string `json:"hash"` + Size string `json:"size"` + ThumbnailLink string `json:"thumbnail_link"` + WebContentLink string `json:"web_content_link"` + Medias []Media `json:"medias"` +} + +func fileToObj(f File) *model.ObjThumb { + size, _ := strconv.ParseInt(f.Size, 10, 64) + return &model.ObjThumb{ + Object: model.Object{ + ID: f.Id, + Name: f.Name, + Size: size, + Ctime: f.CreatedTime, + Modified: f.ModifiedTime, + IsFolder: f.Kind == "drive#folder", + HashInfo: utils.NewHashInfo(hash_extend.GCID, f.Hash), + }, + Thumbnail: model.Thumbnail{ + Thumbnail: f.ThumbnailLink, + }, + } +} + +type Media struct { + MediaId string `json:"media_id"` + MediaName string `json:"media_name"` + Video struct { + Height int `json:"height"` + Width int `json:"width"` + Duration int `json:"duration"` + BitRate int `json:"bit_rate"` + FrameRate int `json:"frame_rate"` + VideoCodec string `json:"video_codec"` + AudioCodec string `json:"audio_codec"` + VideoType string `json:"video_type"` + } `json:"video"` + Link struct { + Url string `json:"url"` + Token string `json:"token"` + Expire time.Time `json:"expire"` + } `json:"link"` + NeedMoreQuota bool `json:"need_more_quota"` + VipTypes []interface{} `json:"vip_types"` + RedirectLink string `json:"redirect_link"` + IconLink string `json:"icon_link"` + IsDefault bool `json:"is_default"` + Priority int `json:"priority"` + IsOrigin bool `json:"is_origin"` + 
ResolutionName string `json:"resolution_name"` + IsVisible bool `json:"is_visible"` + Category string `json:"category"` +} + +type UploadTaskData struct { + UploadType string `json:"upload_type"` + //UPLOAD_TYPE_RESUMABLE + Resumable *struct { + Kind string `json:"kind"` + Params S3Params `json:"params"` + Provider string `json:"provider"` + } `json:"resumable"` + + File File `json:"file"` +} + +type S3Params struct { + AccessKeyID string `json:"access_key_id"` + AccessKeySecret string `json:"access_key_secret"` + Bucket string `json:"bucket"` + Endpoint string `json:"endpoint"` + Expiration time.Time `json:"expiration"` + Key string `json:"key"` + SecurityToken string `json:"security_token"` +} + +// 添加离线下载响应 +type OfflineDownloadResp struct { + File *string `json:"file"` + Task OfflineTask `json:"task"` + UploadType string `json:"upload_type"` + URL struct { + Kind string `json:"kind"` + } `json:"url"` +} + +// 离线下载列表 +type OfflineListResp struct { + ExpiresIn int64 `json:"expires_in"` + NextPageToken string `json:"next_page_token"` + Tasks []OfflineTask `json:"tasks"` +} + +// offlineTask +type OfflineTask struct { + Callback string `json:"callback"` + CreatedTime string `json:"created_time"` + FileID string `json:"file_id"` + FileName string `json:"file_name"` + FileSize string `json:"file_size"` + IconLink string `json:"icon_link"` + ID string `json:"id"` + Kind string `json:"kind"` + Message string `json:"message"` + Name string `json:"name"` + Params Params `json:"params"` + Phase string `json:"phase"` // PHASE_TYPE_RUNNING, PHASE_TYPE_ERROR, PHASE_TYPE_COMPLETE, PHASE_TYPE_PENDING + Progress int64 `json:"progress"` + ReferenceResource ReferenceResource `json:"reference_resource"` + Space string `json:"space"` + StatusSize int64 `json:"status_size"` + Statuses []string `json:"statuses"` + ThirdTaskID string `json:"third_task_id"` + Type string `json:"type"` + UpdatedTime string `json:"updated_time"` + UserID string `json:"user_id"` +} + +type Params struct { + Age string `json:"age"` + MIMEType *string `json:"mime_type,omitempty"` + PredictType string `json:"predict_type"` + URL string `json:"url"` +} + +type ReferenceResource struct { + Type string `json:"@type"` + Audit interface{} `json:"audit"` + Hash string `json:"hash"` + IconLink string `json:"icon_link"` + ID string `json:"id"` + Kind string `json:"kind"` + Medias []Media `json:"medias"` + MIMEType string `json:"mime_type"` + Name string `json:"name"` + Params map[string]interface{} `json:"params"` + ParentID string `json:"parent_id"` + Phase string `json:"phase"` + Size string `json:"size"` + Space string `json:"space"` + Starred bool `json:"starred"` + Tags []string `json:"tags"` + ThumbnailLink string `json:"thumbnail_link"` +} + +type ErrResp struct { + ErrorCode int64 `json:"error_code"` + ErrorMsg string `json:"error"` + ErrorDescription string `json:"error_description"` +} + +func (e *ErrResp) IsError() bool { + return e.ErrorCode != 0 || e.ErrorMsg != "" || e.ErrorDescription != "" +} + +func (e *ErrResp) Error() string { + return fmt.Sprintf("ErrorCode: %d ,Error: %s ,ErrorDescription: %s ", e.ErrorCode, e.ErrorMsg, e.ErrorDescription) +} + +type CaptchaTokenRequest struct { + Action string `json:"action"` + CaptchaToken string `json:"captcha_token"` + ClientID string `json:"client_id"` + DeviceID string `json:"device_id"` + Meta map[string]string `json:"meta"` + RedirectUri string `json:"redirect_uri"` +} + +type CaptchaTokenResponse struct { + CaptchaToken string `json:"captcha_token"` + ExpiresIn int64 
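// *ErrResp implements the error interface via Error() above;
// requestWithCaptchaToken in util.go relies on asserting err.(*ErrResp) to
// branch on the numeric code. A sketch of that pattern, not code from this file:
//
//	if e, ok := err.(*ErrResp); ok && e.ErrorCode == 9 {
//		// captcha token expired: refresh it and retry
//	}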
`json:"expires_in"` + Url string `json:"url"` +} diff --git a/drivers/pikpak_proxy/util.go b/drivers/pikpak_proxy/util.go new file mode 100644 index 0000000000000000000000000000000000000000..9e69023d90aadc87e59daa12da18cac1fc49515e --- /dev/null +++ b/drivers/pikpak_proxy/util.go @@ -0,0 +1,723 @@ +package PikPakProxy + +import ( + "bytes" + "crypto/md5" + "crypto/sha1" + "encoding/hex" + "fmt" + "io" + "net/http" + "path/filepath" + "regexp" + "strings" + "sync" + "time" + + "github.com/alist-org/alist/v3/drivers/base" + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/internal/op" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/aliyun/aliyun-oss-go-sdk/oss" + "github.com/go-resty/resty/v2" + jsoniter "github.com/json-iterator/go" + "github.com/pkg/errors" +) + +var AndroidAlgorithms = []string{ + "SOP04dGzk0TNO7t7t9ekDbAmx+eq0OI1ovEx", + "nVBjhYiND4hZ2NCGyV5beamIr7k6ifAsAbl", + "Ddjpt5B/Cit6EDq2a6cXgxY9lkEIOw4yC1GDF28KrA", + "VVCogcmSNIVvgV6U+AochorydiSymi68YVNGiz", + "u5ujk5sM62gpJOsB/1Gu/zsfgfZO", + "dXYIiBOAHZgzSruaQ2Nhrqc2im", + "z5jUTBSIpBN9g4qSJGlidNAutX6", + "KJE2oveZ34du/g1tiimm", +} + +var WebAlgorithms = []string{ + "C9qPpZLN8ucRTaTiUMWYS9cQvWOE", + "+r6CQVxjzJV6LCV", + "F", + "pFJRC", + "9WXYIDGrwTCz2OiVlgZa90qpECPD6olt", + "/750aCr4lm/Sly/c", + "RB+DT/gZCrbV", + "", + "CyLsf7hdkIRxRm215hl", + "7xHvLi2tOYP0Y92b", + "ZGTXXxu8E/MIWaEDB+Sm/", + "1UI3", + "E7fP5Pfijd+7K+t6Tg/NhuLq0eEUVChpJSkrKxpO", + "ihtqpG6FMt65+Xk+tWUH2", + "NhXXU9rg4XXdzo7u5o", +} + +var PCAlgorithms = []string{ + "KHBJ07an7ROXDoK7Db", + "G6n399rSWkl7WcQmw5rpQInurc1DkLmLJqE", + "JZD1A3M4x+jBFN62hkr7VDhkkZxb9g3rWqRZqFAAb", + "fQnw/AmSlbbI91Ik15gpddGgyU7U", + "/Dv9JdPYSj3sHiWjouR95NTQff", + "yGx2zuTjbWENZqecNI+edrQgqmZKP", + "ljrbSzdHLwbqcRn", + "lSHAsqCkGDGxQqqwrVu", + "TsWXI81fD1", + "vk7hBjawK/rOSrSWajtbMk95nfgf3", +} + +const ( + OSSUserAgent = "aliyun-sdk-android/2.9.13(Linux/Android 14/M2004j7ac;UKQ1.231108.001)" + OssSecurityTokenHeaderName = "X-OSS-Security-Token" + ThreadsNum = 10 +) + +const ( + AndroidClientID = "YNxT9w7GMdWvEOKa" + AndroidClientSecret = "dbw2OtmVEeuUvIptb1Coyg" + AndroidClientVersion = "1.53.2" + AndroidPackageName = "com.pikcloud.pikpak" + AndroidSdkVersion = "2.0.6.206003" + WebClientID = "YUMx5nI8ZU8Ap8pm" + WebClientSecret = "dbw2OtmVEeuUvIptb1Coyg" + WebClientVersion = "2.0.0" + WebPackageName = "mypikpak.com" + WebSdkVersion = "8.0.3" + PCClientID = "YvtoWO6GNHiuCl7x" + PCClientSecret = "1NIH5R1IEe2pAxZE3hv3uA" + PCClientVersion = "undefined" // 2.6.11.4955 + PCPackageName = "mypikpak.com" + PCSdkVersion = "8.0.3" +) + +func (d *PikPakProxy) login() error { + // 检查用户名和密码是否为空 + if d.Addition.Username == "" || d.Addition.Password == "" { + return errors.New("username or password is empty") + } + + url := "https://user.mypikpak.net/v1/auth/signin" + // 使用 用户填写的 CaptchaToken —————— (验证后的captcha_token) + if d.GetCaptchaToken() == "" { + if err := d.RefreshCaptchaTokenInLogin(GetAction(http.MethodPost, url), d.Username); err != nil { + return err + } + } + + if d.Addition.UseProxy { + if strings.HasSuffix(d.Addition.ProxyUrl, "/") { + url = d.Addition.ProxyUrl + url + } else { + url = d.Addition.ProxyUrl + "/" + url + } + + } + + var e ErrResp + res, err := base.RestyClient.SetRetryCount(1).R().SetError(&e).SetBody(base.Json{ + "captcha_token": d.GetCaptchaToken(), + "client_id": d.ClientID, + "client_secret": d.ClientSecret, + "username": d.Username, + "password": d.Password, + 
}).SetQueryParam("client_id", d.ClientID).Post(url) + if err != nil { + return err + } + if e.ErrorCode != 0 { + return &e + } + data := res.Body() + d.RefreshToken = jsoniter.Get(data, "refresh_token").ToString() + d.AccessToken = jsoniter.Get(data, "access_token").ToString() + d.Common.SetUserID(jsoniter.Get(data, "sub").ToString()) + return nil +} + +func (d *PikPakProxy) refreshToken(refreshToken string) error { + url := "https://user.mypikpak.net/v1/auth/token" + if d.Addition.UseProxy { + if strings.HasSuffix(d.Addition.ProxyUrl, "/") { + url = d.Addition.ProxyUrl + url + } else { + url = d.Addition.ProxyUrl + "/" + url + } + + } + var e ErrResp + res, err := base.RestyClient.SetRetryCount(1).R().SetError(&e). + SetHeader("user-agent", "").SetBody(base.Json{ + "client_id": d.ClientID, + "client_secret": d.ClientSecret, + "grant_type": "refresh_token", + "refresh_token": refreshToken, + }).SetQueryParam("client_id", d.ClientID).Post(url) + if err != nil { + d.Status = err.Error() + op.MustSaveDriverStorage(d) + return err + } + if e.ErrorCode != 0 { + if e.ErrorCode == 4126 { + // 1. 未填写 username 或 password + if d.Addition.Username == "" || d.Addition.Password == "" { + return errors.New("refresh_token invalid, please re-provide refresh_token") + } else { + // refresh_token invalid, re-login + return d.login() + } + } + d.Status = e.Error() + op.MustSaveDriverStorage(d) + return errors.New(e.Error()) + } + data := res.Body() + d.Status = "work" + d.RefreshToken = jsoniter.Get(data, "refresh_token").ToString() + d.AccessToken = jsoniter.Get(data, "access_token").ToString() + d.Common.SetUserID(jsoniter.Get(data, "sub").ToString()) + d.Addition.RefreshToken = d.RefreshToken + op.MustSaveDriverStorage(d) + return nil +} + +func (d *PikPakProxy) request(url string, method string, callback base.ReqCallback, resp interface{}) ([]byte, error) { + req := base.RestyClient.R() + req.SetHeaders(map[string]string{ + //"Authorization": "Bearer " + d.AccessToken, + "User-Agent": d.GetUserAgent(), + "X-Device-ID": d.GetDeviceID(), + "X-Captcha-Token": d.GetCaptchaToken(), + }) + + if d.Addition.UseProxy { + if strings.HasSuffix(d.Addition.ProxyUrl, "/") { + url = d.Addition.ProxyUrl + url + } else { + url = d.Addition.ProxyUrl + "/" + url + } + + } + + if d.AccessToken != "" { + req.SetHeader("Authorization", "Bearer "+d.AccessToken) + } + + if callback != nil { + callback(req) + } + if resp != nil { + req.SetResult(resp) + } + var e ErrResp + req.SetError(&e) + res, err := req.Execute(method, url) + if err != nil { + return nil, err + } + + switch e.ErrorCode { + case 0: + return res.Body(), nil + case 4122, 4121, 16: + // access_token 过期 + if err1 := d.refreshToken(d.RefreshToken); err1 != nil { + return nil, err1 + } + return d.request(url, method, callback, resp) + case 9: // 验证码token过期 + if err = d.RefreshCaptchaTokenAtLogin(GetAction(method, url), d.GetUserID()); err != nil { + return nil, err + } + return d.request(url, method, callback, resp) + case 10: // 操作频繁 + return nil, errors.New(e.ErrorDescription) + default: + return nil, errors.New(e.Error()) + } +} + +func (d *PikPakProxy) requestWithCaptchaToken(url string, method string, callback base.ReqCallback, resp interface{}) ([]byte, error) { + + if err := d.RefreshCaptchaTokenAtLogin(GetAction(method, url), d.Common.UserID); err != nil { + return nil, err + } + + data, err := d.request(url, method, func(req *resty.Request) { + req.SetHeaders(map[string]string{ + "User-Agent": d.GetUserAgent(), + "X-Device-ID": d.GetDeviceID(), + 
"X-Captcha-Token": d.GetCaptchaToken(), + }) + if callback != nil { + callback(req) + } + }, resp) + + errResp, ok := err.(*ErrResp) + + if !ok { + return nil, err + } + + switch errResp.ErrorCode { + case 0: + return data, nil + //case 4122, 4121, 10, 16: + // if d.refreshTokenFunc != nil { + // if err = xc.refreshTokenFunc(); err == nil { + // break + // } + // } + // return nil, err + case 9: // 验证码token过期 + if err = d.RefreshCaptchaTokenAtLogin(GetAction(method, url), d.Common.UserID); err != nil { + return nil, err + } + default: + return nil, err + } + return d.request(url, method, callback, resp) +} + +func (d *PikPakProxy) getFiles(id string) ([]File, error) { + res := make([]File, 0) + pageToken := "first" + for pageToken != "" { + if pageToken == "first" { + pageToken = "" + } + query := map[string]string{ + "parent_id": id, + "thumbnail_size": "SIZE_LARGE", + "with_audit": "true", + "limit": "100", + "filters": `{"phase":{"eq":"PHASE_TYPE_COMPLETE"},"trashed":{"eq":false}}`, + "page_token": pageToken, + } + var resp Files + _, err := d.request("https://api-drive.mypikpak.net/drive/v1/files", http.MethodGet, func(req *resty.Request) { + req.SetQueryParams(query) + }, &resp) + if err != nil { + return nil, err + } + pageToken = resp.NextPageToken + res = append(res, resp.Files...) + } + return res, nil +} + +func GetAction(method string, url string) string { + urlpath := regexp.MustCompile(`://[^/]+((/[^/\s?#]+)*)`).FindStringSubmatch(url)[1] + return method + ":" + urlpath +} + +type Common struct { + client *resty.Client + CaptchaToken string + UserID string + // 必要值,签名相关 + ClientID string + ClientSecret string + ClientVersion string + PackageName string + Algorithms []string + DeviceID string + UserAgent string + // 验证码token刷新成功回调 + RefreshCTokenCk func(token string) + //代理设置 + UseProxy bool + //代理地址 + ProxyUrl string +} + +func generateDeviceSign(deviceID, packageName string) string { + + signatureBase := fmt.Sprintf("%s%s%s%s", deviceID, packageName, "1", "appkey") + + sha1Hash := sha1.New() + sha1Hash.Write([]byte(signatureBase)) + sha1Result := sha1Hash.Sum(nil) + + sha1String := hex.EncodeToString(sha1Result) + + md5Hash := md5.New() + md5Hash.Write([]byte(sha1String)) + md5Result := md5Hash.Sum(nil) + + md5String := hex.EncodeToString(md5Result) + + deviceSign := fmt.Sprintf("div101.%s%s", deviceID, md5String) + + return deviceSign +} + +func BuildCustomUserAgent(deviceID, clientID, appName, sdkVersion, clientVersion, packageName, userID string) string { + deviceSign := generateDeviceSign(deviceID, packageName) + var sb strings.Builder + + sb.WriteString(fmt.Sprintf("ANDROID-%s/%s ", appName, clientVersion)) + sb.WriteString("protocolVersion/200 ") + sb.WriteString("accesstype/ ") + sb.WriteString(fmt.Sprintf("clientid/%s ", clientID)) + sb.WriteString(fmt.Sprintf("clientversion/%s ", clientVersion)) + sb.WriteString("action_type/ ") + sb.WriteString("networktype/WIFI ") + sb.WriteString("sessionid/ ") + sb.WriteString(fmt.Sprintf("deviceid/%s ", deviceID)) + sb.WriteString("providername/NONE ") + sb.WriteString(fmt.Sprintf("devicesign/%s ", deviceSign)) + sb.WriteString("refresh_token/ ") + sb.WriteString(fmt.Sprintf("sdkversion/%s ", sdkVersion)) + sb.WriteString(fmt.Sprintf("datetime/%d ", time.Now().UnixMilli())) + sb.WriteString(fmt.Sprintf("usrno/%s ", userID)) + sb.WriteString(fmt.Sprintf("appname/android-%s ", appName)) + sb.WriteString(fmt.Sprintf("session_origin/ ")) + sb.WriteString(fmt.Sprintf("grant_type/ ")) + sb.WriteString(fmt.Sprintf("appid/ ")) + 
sb.WriteString(fmt.Sprintf("clientip/ ")) + sb.WriteString(fmt.Sprintf("devicename/Xiaomi_M2004j7ac ")) + sb.WriteString(fmt.Sprintf("osversion/13 ")) + sb.WriteString(fmt.Sprintf("platformversion/10 ")) + sb.WriteString(fmt.Sprintf("accessmode/ ")) + sb.WriteString(fmt.Sprintf("devicemodel/M2004J7AC ")) + + return sb.String() +} + +func (c *Common) SetDeviceID(deviceID string) { + c.DeviceID = deviceID +} + +func (c *Common) SetUserID(userID string) { + c.UserID = userID +} + +func (c *Common) SetUserAgent(userAgent string) { + c.UserAgent = userAgent +} + +func (c *Common) SetCaptchaToken(captchaToken string) { + c.CaptchaToken = captchaToken +} +func (c *Common) GetCaptchaToken() string { + return c.CaptchaToken +} + +func (c *Common) GetUserAgent() string { + return c.UserAgent +} + +func (c *Common) GetDeviceID() string { + return c.DeviceID +} + +func (c *Common) GetUserID() string { + return c.UserID +} + +// RefreshCaptchaTokenAtLogin 刷新验证码token(登录后) +func (d *PikPakProxy) RefreshCaptchaTokenAtLogin(action, userID string) error { + metas := map[string]string{ + "client_version": d.ClientVersion, + "package_name": d.PackageName, + "user_id": userID, + } + metas["timestamp"], metas["captcha_sign"] = d.Common.GetCaptchaSign() + return d.refreshCaptchaToken(action, metas) +} + +// RefreshCaptchaTokenInLogin 刷新验证码token(登录时) +func (d *PikPakProxy) RefreshCaptchaTokenInLogin(action, username string) error { + metas := make(map[string]string) + if ok, _ := regexp.MatchString(`\w+([-+.]\w+)*@\w+([-.]\w+)*\.\w+([-.]\w+)*`, username); ok { + metas["email"] = username + } else if len(username) >= 11 && len(username) <= 18 { + metas["phone_number"] = username + } else { + metas["username"] = username + } + return d.refreshCaptchaToken(action, metas) +} + +// GetCaptchaSign 获取验证码签名 +func (c *Common) GetCaptchaSign() (timestamp, sign string) { + timestamp = fmt.Sprint(time.Now().UnixMilli()) + str := fmt.Sprint(c.ClientID, c.ClientVersion, c.PackageName, c.DeviceID, timestamp) + for _, algorithm := range c.Algorithms { + str = utils.GetMD5EncodeStr(str + algorithm) + } + sign = "1." + str + return +} + +// refreshCaptchaToken 刷新CaptchaToken +func (d *PikPakProxy) refreshCaptchaToken(action string, metas map[string]string) error { + param := CaptchaTokenRequest{ + Action: action, + CaptchaToken: d.GetCaptchaToken(), + ClientID: d.ClientID, + DeviceID: d.GetDeviceID(), + Meta: metas, + RedirectUri: "xlaccsdk01://xbase.cloud/callback?state=harbor", + } + var e ErrResp + var resp CaptchaTokenResponse + _, err := d.request("https://user.mypikpak.net/v1/shield/captcha/init", http.MethodPost, func(req *resty.Request) { + req.SetError(&e).SetBody(param).SetQueryParam("client_id", d.ClientID) + }, &resp) + + if err != nil { + return err + } + + if e.IsError() { + return errors.New(e.Error()) + } + + if resp.Url != "" { + return fmt.Errorf(`need verify: Click Here`, resp.Url) + } + + if d.Common.RefreshCTokenCk != nil { + d.Common.RefreshCTokenCk(resp.CaptchaToken) + } + d.Common.SetCaptchaToken(resp.CaptchaToken) + return nil +} + +func (d *PikPakProxy) UploadByOSS(params *S3Params, stream model.FileStreamer, up driver.UpdateProgress) error { + ossClient, err := oss.New(params.Endpoint, params.AccessKeyID, params.AccessKeySecret) + if err != nil { + return err + } + bucket, err := ossClient.Bucket(params.Bucket) + if err != nil { + return err + } + + err = bucket.PutObject(params.Key, stream, OssOption(params)...) 
+ if err != nil { + return err + } + return nil +} +func (d *PikPakProxy) UploadByMultipart(params *S3Params, fileSize int64, stream model.FileStreamer, up driver.UpdateProgress) error { + var ( + chunks []oss.FileChunk + parts []oss.UploadPart + imur oss.InitiateMultipartUploadResult + ossClient *oss.Client + bucket *oss.Bucket + err error + ) + + tmpF, err := stream.CacheFullInTempFile() + if err != nil { + return err + } + + if ossClient, err = oss.New(params.Endpoint, params.AccessKeyID, params.AccessKeySecret); err != nil { + return err + } + + if bucket, err = ossClient.Bucket(params.Bucket); err != nil { + return err + } + + ticker := time.NewTicker(time.Hour * 12) + defer ticker.Stop() + // 设置超时 + timeout := time.NewTimer(time.Hour * 24) + + if chunks, err = SplitFile(fileSize); err != nil { + return err + } + + if imur, err = bucket.InitiateMultipartUpload(params.Key, + oss.SetHeader(OssSecurityTokenHeaderName, params.SecurityToken), + oss.UserAgentHeader(OSSUserAgent), + ); err != nil { + return err + } + + wg := sync.WaitGroup{} + wg.Add(len(chunks)) + + chunksCh := make(chan oss.FileChunk) + errCh := make(chan error) + UploadedPartsCh := make(chan oss.UploadPart) + quit := make(chan struct{}) + + // producer + go chunksProducer(chunksCh, chunks) + go func() { + wg.Wait() + quit <- struct{}{} + }() + + // consumers + for i := 0; i < ThreadsNum; i++ { + go func(threadId int) { + defer func() { + if r := recover(); r != nil { + errCh <- fmt.Errorf("recovered in %v", r) + } + }() + for chunk := range chunksCh { + var part oss.UploadPart // 出现错误就继续尝试,共尝试3次 + for retry := 0; retry < 3; retry++ { + select { + case <-ticker.C: + errCh <- errors.Wrap(err, "ossToken 过期") + default: + } + + buf := make([]byte, chunk.Size) + if _, err = tmpF.ReadAt(buf, chunk.Offset); err != nil && !errors.Is(err, io.EOF) { + continue + } + + b := bytes.NewBuffer(buf) + if part, err = bucket.UploadPart(imur, b, chunk.Size, chunk.Number, OssOption(params)...); err == nil { + break + } + } + if err != nil { + errCh <- errors.Wrap(err, fmt.Sprintf("上传 %s 的第%d个分片时出现错误:%v", stream.GetName(), chunk.Number, err)) + } + UploadedPartsCh <- part + } + }(i) + } + + go func() { + for part := range UploadedPartsCh { + parts = append(parts, part) + wg.Done() + } + }() +LOOP: + for { + select { + case <-ticker.C: + // ossToken 过期 + return err + case <-quit: + break LOOP + case <-errCh: + return err + case <-timeout.C: + return fmt.Errorf("time out") + } + } + + // EOF错误是xml的Unmarshal导致的,响应其实是json格式,所以实际上上传是成功的 + if _, err = bucket.CompleteMultipartUpload(imur, parts, OssOption(params)...); err != nil && !errors.Is(err, io.EOF) { + // 当文件名含有 &< 这两个字符之一时响应的xml解析会出现错误,实际上上传是成功的 + if filename := filepath.Base(stream.GetName()); !strings.ContainsAny(filename, "&<") { + return err + } + } + return nil +} + +func chunksProducer(ch chan oss.FileChunk, chunks []oss.FileChunk) { + for _, chunk := range chunks { + ch <- chunk + } +} + +func SplitFile(fileSize int64) (chunks []oss.FileChunk, err error) { + for i := int64(1); i < 10; i++ { + if fileSize < i*utils.GB { // 文件大小小于iGB时分为i*100片 + if chunks, err = SplitFileByPartNum(fileSize, int(i*100)); err != nil { + return + } + break + } + } + if fileSize > 9*utils.GB { // 文件大小大于9GB时分为1000片 + if chunks, err = SplitFileByPartNum(fileSize, 1000); err != nil { + return + } + } + // 单个分片大小不能小于1MB + if chunks[0].Size < 1*utils.MB { + if chunks, err = SplitFileByPartSize(fileSize, 1*utils.MB); err != nil { + return + } + } + return +} + +// SplitFileByPartNum splits big file into parts by 
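// SplitFile above scales part count with file size: a file under i GB
// (i = 1..9) is cut into i*100 parts (e.g. a 2.5 GB file falls in the i=3
// bucket, so 300 parts), anything over 9 GB is fixed at 1000 parts, and the
// result is re-split at a 1 MB minimum part size if the first pass produced
// smaller parts. Note the boundary: a file of exactly 9 GB satisfies neither
// branch, leaving chunks empty, so the chunks[0] access would panic.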
the num of parts. +// Split the file with specified parts count, returns the split result when error is nil. +func SplitFileByPartNum(fileSize int64, chunkNum int) ([]oss.FileChunk, error) { + if chunkNum <= 0 || chunkNum > 10000 { + return nil, errors.New("chunkNum invalid") + } + + if int64(chunkNum) > fileSize { + return nil, errors.New("oss: chunkNum invalid") + } + + var chunks []oss.FileChunk + chunk := oss.FileChunk{} + chunkN := (int64)(chunkNum) + for i := int64(0); i < chunkN; i++ { + chunk.Number = int(i + 1) + chunk.Offset = i * (fileSize / chunkN) + if i == chunkN-1 { + chunk.Size = fileSize/chunkN + fileSize%chunkN + } else { + chunk.Size = fileSize / chunkN + } + chunks = append(chunks, chunk) + } + + return chunks, nil +} + +// SplitFileByPartSize splits big file into parts by the size of parts. +// Splits the file by the part size. Returns the FileChunk when error is nil. +func SplitFileByPartSize(fileSize int64, chunkSize int64) ([]oss.FileChunk, error) { + if chunkSize <= 0 { + return nil, errors.New("chunkSize invalid") + } + + chunkN := fileSize / chunkSize + if chunkN >= 10000 { + return nil, errors.New("Too many parts, please increase part size") + } + + var chunks []oss.FileChunk + chunk := oss.FileChunk{} + for i := int64(0); i < chunkN; i++ { + chunk.Number = int(i + 1) + chunk.Offset = i * chunkSize + chunk.Size = chunkSize + chunks = append(chunks, chunk) + } + + if fileSize%chunkSize > 0 { + chunk.Number = len(chunks) + 1 + chunk.Offset = int64(len(chunks)) * chunkSize + chunk.Size = fileSize % chunkSize + chunks = append(chunks, chunk) + } + + return chunks, nil +} + +// OssOption get options +func OssOption(params *S3Params) []oss.Option { + options := []oss.Option{ + oss.SetHeader(OssSecurityTokenHeaderName, params.SecurityToken), + oss.UserAgentHeader(OSSUserAgent), + } + return options +} diff --git a/drivers/pikpak_share/driver.go b/drivers/pikpak_share/driver.go new file mode 100644 index 0000000000000000000000000000000000000000..7ec9cb28dc6c44249266c1a14bb9e37f873079b5 --- /dev/null +++ b/drivers/pikpak_share/driver.go @@ -0,0 +1,143 @@ +package pikpak_share + +import ( + "context" + "github.com/alist-org/alist/v3/internal/op" + "net/http" + "strings" + "time" + + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/go-resty/resty/v2" +) + +type PikPakShare struct { + model.Storage + Addition + *Common + PassCodeToken string +} + +func (d *PikPakShare) Config() driver.Config { + return config +} + +func (d *PikPakShare) GetAddition() driver.Additional { + return &d.Addition +} + +func (d *PikPakShare) Init(ctx context.Context) error { + if d.Common == nil { + d.Common = &Common{ + DeviceID: utils.GetMD5EncodeStr(d.Addition.ShareId + d.Addition.SharePwd + time.Now().String()), + UserAgent: "", + RefreshCTokenCk: func(token string) { + d.Common.CaptchaToken = token + op.MustSaveDriverStorage(d) + }, + UseProxy: d.Addition.UseProxy, + ProxyUrl: d.Addition.ProxyUrl, + } + } + + + if d.Addition.DeviceID != "" { + d.SetDeviceID(d.Addition.DeviceID) + } else { + d.Addition.DeviceID = d.Common.DeviceID + op.MustSaveDriverStorage(d) + } + + if d.Platform == "android" { + d.ClientID = AndroidClientID + d.ClientSecret = AndroidClientSecret + d.ClientVersion = AndroidClientVersion + d.PackageName = AndroidPackageName + d.Algorithms = AndroidAlgorithms + d.UserAgent = BuildCustomUserAgent(d.GetDeviceID(), AndroidClientID, AndroidPackageName, AndroidSdkVersion, 
AndroidClientVersion, AndroidPackageName, "") + } else if d.Platform == "web" { + d.ClientID = WebClientID + d.ClientSecret = WebClientSecret + d.ClientVersion = WebClientVersion + d.PackageName = WebPackageName + d.Algorithms = WebAlgorithms + d.UserAgent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/117.0.0.0 Safari/537.36" + } else if d.Platform == "pc" { + d.ClientID = PCClientID + d.ClientSecret = PCClientSecret + d.ClientVersion = PCClientVersion + d.PackageName = PCPackageName + d.Algorithms = PCAlgorithms + d.UserAgent = "MainWindow Mozilla/5.0 (Windows NT 10.0; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) PikPak/2.6.11.4955 Chrome/100.0.4896.160 Electron/18.3.15 Safari/537.36" + } + + // 获取CaptchaToken + err := d.RefreshCaptchaToken(GetAction(http.MethodGet, "https://api-drive.mypikpak.net/drive/v1/share:batch_file_info"), "") + if err != nil { + return err + } + + if d.SharePwd != "" { + return d.getSharePassToken() + } + + return nil +} + +func (d *PikPakShare) Drop(ctx context.Context) error { + return nil +} + +func (d *PikPakShare) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) { + files, err := d.getFiles(dir.GetID()) + if err != nil { + return nil, err + } + return utils.SliceConvert(files, func(src File) (model.Obj, error) { + return fileToObj(src), nil + }) +} + +func (d *PikPakShare) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) { + var resp ShareResp + query := map[string]string{ + "share_id": d.ShareId, + "file_id": file.GetID(), + "pass_code_token": d.PassCodeToken, + } + _, err := d.request("https://api-drive.mypikpak.net/drive/v1/share/file_info", http.MethodGet, func(req *resty.Request) { + req.SetQueryParams(query) + }, &resp) + if err != nil { + return nil, err + } + + downloadUrl := resp.FileInfo.WebContentLink + if downloadUrl == "" && len(resp.FileInfo.Medias) > 0 { + // 使用转码后的链接 + if d.Addition.UseTransCodingAddress && len(resp.FileInfo.Medias) > 1 { + downloadUrl = resp.FileInfo.Medias[1].Link.Url + } else { + downloadUrl = resp.FileInfo.Medias[0].Link.Url + } + + } + + if d.Addition.UseProxy { + if strings.HasSuffix(d.Addition.ProxyUrl, "/") { + downloadUrl = d.Addition.ProxyUrl + downloadUrl + } else { + downloadUrl = d.Addition.ProxyUrl + "/" + downloadUrl + } + + } + + + return &model.Link{ + URL: downloadUrl, + }, nil +} + +var _ driver.Driver = (*PikPakShare)(nil) diff --git a/drivers/pikpak_share/meta.go b/drivers/pikpak_share/meta.go new file mode 100644 index 0000000000000000000000000000000000000000..e6266ed62f4adf03fe64565ef2f793f03fa0851e --- /dev/null +++ b/drivers/pikpak_share/meta.go @@ -0,0 +1,32 @@ +package pikpak_share + +import ( + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/op" +) + +type Addition struct { + driver.RootID + ShareId string `json:"share_id" required:"true"` + SharePwd string `json:"share_pwd"` + Platform string `json:"platform" default:"web" required:"true" type:"select" options:"android,web,pc"` + DeviceID string `json:"device_id" required:"false" default:""` + UseTransCodingAddress bool `json:"use_transcoding_address" required:"true" default:"false"` + //是否使用代理 + UseProxy bool `json:"use_proxy"` + //下代理地址 + ProxyUrl string `json:"proxy_url" default:""` +} + +var config = driver.Config{ + Name: "PikPakShare", + LocalSort: true, + NoUpload: true, + DefaultRoot: "", +} + +func init() { + op.RegisterDriver(func() driver.Driver { + return &PikPakShare{} + }) +} 
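// An illustrative storage configuration for this share driver (field names
// taken from the Addition struct above; all values are placeholders):
//
//	{
//	  "share_id": "VO_xxxxxxxx",
//	  "share_pwd": "abcd",
//	  "platform": "web",
//	  "use_transcoding_address": false,
//	  "use_proxy": true,
//	  "proxy_url": "https://proxy.example.com"
//	}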
diff --git a/drivers/pikpak_share/types.go b/drivers/pikpak_share/types.go new file mode 100644 index 0000000000000000000000000000000000000000..78ea2ff8bbc26f9d8ee257246f023c0f92643f5e --- /dev/null +++ b/drivers/pikpak_share/types.go @@ -0,0 +1,105 @@ +package pikpak_share + +import ( + "fmt" + "strconv" + "time" + + "github.com/alist-org/alist/v3/internal/model" +) + +type ShareResp struct { + ShareStatus string `json:"share_status"` + ShareStatusText string `json:"share_status_text"` + FileInfo File `json:"file_info"` + Files []File `json:"files"` + NextPageToken string `json:"next_page_token"` + PassCodeToken string `json:"pass_code_token"` +} + +type File struct { + Id string `json:"id"` + ShareId string `json:"share_id"` + Kind string `json:"kind"` + Name string `json:"name"` + ModifiedTime time.Time `json:"modified_time"` + Size string `json:"size"` + ThumbnailLink string `json:"thumbnail_link"` + WebContentLink string `json:"web_content_link"` + Medias []Media `json:"medias"` +} + +func fileToObj(f File) *model.ObjThumb { + size, _ := strconv.ParseInt(f.Size, 10, 64) + return &model.ObjThumb{ + Object: model.Object{ + ID: f.Id, + Name: f.Name, + Size: size, + Modified: f.ModifiedTime, + IsFolder: f.Kind == "drive#folder", + }, + Thumbnail: model.Thumbnail{ + Thumbnail: f.ThumbnailLink, + }, + } +} + +type Media struct { + MediaId string `json:"media_id"` + MediaName string `json:"media_name"` + Video struct { + Height int `json:"height"` + Width int `json:"width"` + Duration int `json:"duration"` + BitRate int `json:"bit_rate"` + FrameRate int `json:"frame_rate"` + VideoCodec string `json:"video_codec"` + AudioCodec string `json:"audio_codec"` + VideoType string `json:"video_type"` + } `json:"video"` + Link struct { + Url string `json:"url"` + Token string `json:"token"` + Expire time.Time `json:"expire"` + } `json:"link"` + NeedMoreQuota bool `json:"need_more_quota"` + VipTypes []interface{} `json:"vip_types"` + RedirectLink string `json:"redirect_link"` + IconLink string `json:"icon_link"` + IsDefault bool `json:"is_default"` + Priority int `json:"priority"` + IsOrigin bool `json:"is_origin"` + ResolutionName string `json:"resolution_name"` + IsVisible bool `json:"is_visible"` + Category string `json:"category"` +} + +type CaptchaTokenRequest struct { + Action string `json:"action"` + CaptchaToken string `json:"captcha_token"` + ClientID string `json:"client_id"` + DeviceID string `json:"device_id"` + Meta map[string]string `json:"meta"` + RedirectUri string `json:"redirect_uri"` +} + +type CaptchaTokenResponse struct { + CaptchaToken string `json:"captcha_token"` + ExpiresIn int64 `json:"expires_in"` + Url string `json:"url"` +} + +type ErrResp struct { + ErrorCode int64 `json:"error_code"` + ErrorMsg string `json:"error"` + ErrorDescription string `json:"error_description"` +} + +func (e *ErrResp) IsError() bool { + return e.ErrorCode != 0 || e.ErrorMsg != "" || e.ErrorDescription != "" +} + +func (e *ErrResp) Error() string { + return fmt.Sprintf("ErrorCode: %d ,Error: %s ,ErrorDescription: %s ", e.ErrorCode, e.ErrorMsg, e.ErrorDescription) +} diff --git a/drivers/pikpak_share/util.go b/drivers/pikpak_share/util.go new file mode 100644 index 0000000000000000000000000000000000000000..e49cb3aec478bb4c2ec18e1d586fbd95c7a85072 --- /dev/null +++ b/drivers/pikpak_share/util.go @@ -0,0 +1,343 @@ +package pikpak_share + +import ( + "crypto/md5" + "crypto/sha1" + "encoding/hex" + "errors" + "fmt" + "github.com/alist-org/alist/v3/pkg/utils" + "net/http" + "regexp" + "strings" + "time" + 
+ "github.com/alist-org/alist/v3/drivers/base" + "github.com/go-resty/resty/v2" +) + +var AndroidAlgorithms = []string{ + "SOP04dGzk0TNO7t7t9ekDbAmx+eq0OI1ovEx", + "nVBjhYiND4hZ2NCGyV5beamIr7k6ifAsAbl", + "Ddjpt5B/Cit6EDq2a6cXgxY9lkEIOw4yC1GDF28KrA", + "VVCogcmSNIVvgV6U+AochorydiSymi68YVNGiz", + "u5ujk5sM62gpJOsB/1Gu/zsfgfZO", + "dXYIiBOAHZgzSruaQ2Nhrqc2im", + "z5jUTBSIpBN9g4qSJGlidNAutX6", + "KJE2oveZ34du/g1tiimm", +} + +var WebAlgorithms = []string{ + "C9qPpZLN8ucRTaTiUMWYS9cQvWOE", + "+r6CQVxjzJV6LCV", + "F", + "pFJRC", + "9WXYIDGrwTCz2OiVlgZa90qpECPD6olt", + "/750aCr4lm/Sly/c", + "RB+DT/gZCrbV", + "", + "CyLsf7hdkIRxRm215hl", + "7xHvLi2tOYP0Y92b", + "ZGTXXxu8E/MIWaEDB+Sm/", + "1UI3", + "E7fP5Pfijd+7K+t6Tg/NhuLq0eEUVChpJSkrKxpO", + "ihtqpG6FMt65+Xk+tWUH2", + "NhXXU9rg4XXdzo7u5o", +} + +var PCAlgorithms = []string{ + "KHBJ07an7ROXDoK7Db", + "G6n399rSWkl7WcQmw5rpQInurc1DkLmLJqE", + "JZD1A3M4x+jBFN62hkr7VDhkkZxb9g3rWqRZqFAAb", + "fQnw/AmSlbbI91Ik15gpddGgyU7U", + "/Dv9JdPYSj3sHiWjouR95NTQff", + "yGx2zuTjbWENZqecNI+edrQgqmZKP", + "ljrbSzdHLwbqcRn", + "lSHAsqCkGDGxQqqwrVu", + "TsWXI81fD1", + "vk7hBjawK/rOSrSWajtbMk95nfgf3", +} + +const ( + AndroidClientID = "YNxT9w7GMdWvEOKa" + AndroidClientSecret = "dbw2OtmVEeuUvIptb1Coyg" + AndroidClientVersion = "1.53.2" + AndroidPackageName = "com.pikcloud.pikpak" + AndroidSdkVersion = "2.0.6.206003" + WebClientID = "YUMx5nI8ZU8Ap8pm" + WebClientSecret = "dbw2OtmVEeuUvIptb1Coyg" + WebClientVersion = "2.0.0" + WebPackageName = "mypikpak.com" + WebSdkVersion = "8.0.3" + PCClientID = "YvtoWO6GNHiuCl7x" + PCClientSecret = "1NIH5R1IEe2pAxZE3hv3uA" + PCClientVersion = "undefined" // 2.6.11.4955 + PCPackageName = "mypikpak.com" + PCSdkVersion = "8.0.3" +) + +func (d *PikPakShare) request(url string, method string, callback base.ReqCallback, resp interface{}) ([]byte, error) { + req := base.RestyClient.R() + req.SetHeaders(map[string]string{ + "User-Agent": d.GetUserAgent(), + "X-Client-ID": d.GetClientID(), + "X-Device-ID": d.GetDeviceID(), + "X-Captcha-Token": d.GetCaptchaToken(), + }) + + if d.Addition.UseProxy { + if strings.HasSuffix(d.Addition.ProxyUrl, "/") { + url = d.Addition.ProxyUrl + url + } else { + url = d.Addition.ProxyUrl + "/" + url + } + + } + + + if callback != nil { + callback(req) + } + if resp != nil { + req.SetResult(resp) + } + var e ErrResp + req.SetError(&e) + res, err := req.Execute(method, url) + if err != nil { + return nil, err + } + switch e.ErrorCode { + case 0: + return res.Body(), nil + case 9: // 验证码token过期 + if err = d.RefreshCaptchaToken(GetAction(method, url), ""); err != nil { + return nil, err + } + return d.request(url, method, callback, resp) + case 10: // 操作频繁 + return nil, errors.New(e.ErrorDescription) + default: + return nil, errors.New(e.Error()) + } +} + +func (d *PikPakShare) getSharePassToken() error { + query := map[string]string{ + "share_id": d.ShareId, + "pass_code": d.SharePwd, + "thumbnail_size": "SIZE_LARGE", + "limit": "100", + } + var resp ShareResp + _, err := d.request("https://api-drive.mypikpak.net/drive/v1/share", http.MethodGet, func(req *resty.Request) { + req.SetQueryParams(query) + }, &resp) + if err != nil { + return err + } + d.PassCodeToken = resp.PassCodeToken + return nil +} + +func (d *PikPakShare) getFiles(id string) ([]File, error) { + res := make([]File, 0) + pageToken := "first" + for pageToken != "" { + if pageToken == "first" { + pageToken = "" + } + query := map[string]string{ + "parent_id": id, + "share_id": d.ShareId, + "thumbnail_size": "SIZE_LARGE", + "with_audit": "true", + 
"limit": "100", + "filters": `{"phase":{"eq":"PHASE_TYPE_COMPLETE"},"trashed":{"eq":false}}`, + "page_token": pageToken, + "pass_code_token": d.PassCodeToken, + } + var resp ShareResp + _, err := d.request("https://api-drive.mypikpak.net/drive/v1/share/detail", http.MethodGet, func(req *resty.Request) { + req.SetQueryParams(query) + }, &resp) + if err != nil { + return nil, err + } + if resp.ShareStatus != "OK" { + if resp.ShareStatus == "PASS_CODE_EMPTY" || resp.ShareStatus == "PASS_CODE_ERROR" { + err = d.getSharePassToken() + if err != nil { + return nil, err + } + return d.getFiles(id) + } + return nil, errors.New(resp.ShareStatusText) + } + pageToken = resp.NextPageToken + res = append(res, resp.Files...) + } + return res, nil +} + +func GetAction(method string, url string) string { + urlpath := regexp.MustCompile(`://[^/]+((/[^/\s?#]+)*)`).FindStringSubmatch(url)[1] + return method + ":" + urlpath +} + +type Common struct { + client *resty.Client + CaptchaToken string + // 必要值,签名相关 + ClientID string + ClientSecret string + ClientVersion string + PackageName string + Algorithms []string + DeviceID string + UserAgent string + // 验证码token刷新成功回调 + RefreshCTokenCk func(token string) + //代理设置 + UseProxy bool + //代理地址 + ProxyUrl string +} + +func (c *Common) SetUserAgent(userAgent string) { + c.UserAgent = userAgent +} + +func (c *Common) SetCaptchaToken(captchaToken string) { + c.CaptchaToken = captchaToken +} + +func (c *Common) SetDeviceID(deviceID string) { + c.DeviceID = deviceID +} + +func (c *Common) GetCaptchaToken() string { + return c.CaptchaToken +} + +func (c *Common) GetClientID() string { + return c.ClientID +} + +func (c *Common) GetUserAgent() string { + return c.UserAgent +} + +func (c *Common) GetDeviceID() string { + return c.DeviceID +} + +func generateDeviceSign(deviceID, packageName string) string { + + signatureBase := fmt.Sprintf("%s%s%s%s", deviceID, packageName, "1", "appkey") + + sha1Hash := sha1.New() + sha1Hash.Write([]byte(signatureBase)) + sha1Result := sha1Hash.Sum(nil) + + sha1String := hex.EncodeToString(sha1Result) + + md5Hash := md5.New() + md5Hash.Write([]byte(sha1String)) + md5Result := md5Hash.Sum(nil) + + md5String := hex.EncodeToString(md5Result) + + deviceSign := fmt.Sprintf("div101.%s%s", deviceID, md5String) + + return deviceSign +} + +func BuildCustomUserAgent(deviceID, clientID, appName, sdkVersion, clientVersion, packageName, userID string) string { + deviceSign := generateDeviceSign(deviceID, packageName) + var sb strings.Builder + + sb.WriteString(fmt.Sprintf("ANDROID-%s/%s ", appName, clientVersion)) + sb.WriteString("protocolVersion/200 ") + sb.WriteString("accesstype/ ") + sb.WriteString(fmt.Sprintf("clientid/%s ", clientID)) + sb.WriteString(fmt.Sprintf("clientversion/%s ", clientVersion)) + sb.WriteString("action_type/ ") + sb.WriteString("networktype/WIFI ") + sb.WriteString("sessionid/ ") + sb.WriteString(fmt.Sprintf("deviceid/%s ", deviceID)) + sb.WriteString("providername/NONE ") + sb.WriteString(fmt.Sprintf("devicesign/%s ", deviceSign)) + sb.WriteString("refresh_token/ ") + sb.WriteString(fmt.Sprintf("sdkversion/%s ", sdkVersion)) + sb.WriteString(fmt.Sprintf("datetime/%d ", time.Now().UnixMilli())) + sb.WriteString(fmt.Sprintf("usrno/%s ", userID)) + sb.WriteString(fmt.Sprintf("appname/android-%s ", appName)) + sb.WriteString(fmt.Sprintf("session_origin/ ")) + sb.WriteString(fmt.Sprintf("grant_type/ ")) + sb.WriteString(fmt.Sprintf("appid/ ")) + sb.WriteString(fmt.Sprintf("clientip/ ")) + 
sb.WriteString(fmt.Sprintf("devicename/Xiaomi_M2004j7ac ")) + sb.WriteString(fmt.Sprintf("osversion/13 ")) + sb.WriteString(fmt.Sprintf("platformversion/10 ")) + sb.WriteString(fmt.Sprintf("accessmode/ ")) + sb.WriteString(fmt.Sprintf("devicemodel/M2004J7AC ")) + + return sb.String() +} + +// RefreshCaptchaToken 刷新验证码token +func (d *PikPakShare) RefreshCaptchaToken(action, userID string) error { + metas := map[string]string{ + "client_version": d.ClientVersion, + "package_name": d.PackageName, + "user_id": userID, + } + metas["timestamp"], metas["captcha_sign"] = d.Common.GetCaptchaSign() + return d.refreshCaptchaToken(action, metas) +} + +// GetCaptchaSign 获取验证码签名 +func (c *Common) GetCaptchaSign() (timestamp, sign string) { + timestamp = fmt.Sprint(time.Now().UnixMilli()) + str := fmt.Sprint(c.ClientID, c.ClientVersion, c.PackageName, c.DeviceID, timestamp) + for _, algorithm := range c.Algorithms { + str = utils.GetMD5EncodeStr(str + algorithm) + } + sign = "1." + str + return +} + +// refreshCaptchaToken 刷新CaptchaToken +func (d *PikPakShare) refreshCaptchaToken(action string, metas map[string]string) error { + param := CaptchaTokenRequest{ + Action: action, + CaptchaToken: d.GetCaptchaToken(), + ClientID: d.ClientID, + DeviceID: d.GetDeviceID(), + Meta: metas, + } + var e ErrResp + var resp CaptchaTokenResponse + _, err := d.request("https://user.mypikpak.net/v1/shield/captcha/init", http.MethodPost, func(req *resty.Request) { + req.SetError(&e).SetBody(param) + }, &resp) + + if err != nil { + return err + } + + if e.IsError() { + return errors.New(e.Error()) + } + + //if resp.Url != "" { + // return fmt.Errorf(`need verify: Click Here`, resp.Url) + //} + + if d.Common.RefreshCTokenCk != nil { + d.Common.RefreshCTokenCk(resp.CaptchaToken) + } + d.Common.SetCaptchaToken(resp.CaptchaToken) + return nil +} diff --git a/drivers/quark_uc/driver.go b/drivers/quark_uc/driver.go new file mode 100644 index 0000000000000000000000000000000000000000..8674fbab26fe95ed2d349ce540cd52cb3f5646e0 --- /dev/null +++ b/drivers/quark_uc/driver.go @@ -0,0 +1,221 @@ +package quark + +import ( + "context" + "crypto/md5" + "crypto/sha1" + "encoding/hex" + "io" + "net/http" + "time" + + "github.com/alist-org/alist/v3/drivers/base" + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/errs" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/go-resty/resty/v2" + log "github.com/sirupsen/logrus" +) + +type QuarkOrUC struct { + model.Storage + Addition + config driver.Config + conf Conf +} + +func (d *QuarkOrUC) Config() driver.Config { + return d.config +} + +func (d *QuarkOrUC) GetAddition() driver.Additional { + return &d.Addition +} + +func (d *QuarkOrUC) Init(ctx context.Context) error { + _, err := d.request("/config", http.MethodGet, nil, nil) + return err +} + +func (d *QuarkOrUC) Drop(ctx context.Context) error { + return nil +} + +func (d *QuarkOrUC) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) { + files, err := d.GetFiles(dir.GetID()) + if err != nil { + return nil, err + } + return utils.SliceConvert(files, func(src File) (model.Obj, error) { + return fileToObj(src), nil + }) +} + +func (d *QuarkOrUC) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) { + data := base.Json{ + "fids": []string{file.GetID()}, + } + var resp DownResp + ua := d.conf.ua + _, err := d.request("/file/download", http.MethodPost, func(req *resty.Request) { + 
req.SetHeader("User-Agent", ua). + SetBody(data) + }, &resp) + if err != nil { + return nil, err + } + + return &model.Link{ + URL: resp.Data[0].DownloadUrl, + Header: http.Header{ + "Cookie": []string{d.Cookie}, + "Referer": []string{d.conf.referer}, + "User-Agent": []string{ua}, + }, + Concurrency: 2, + PartSize: 10 * utils.MB, + }, nil +} + +func (d *QuarkOrUC) MakeDir(ctx context.Context, parentDir model.Obj, dirName string) error { + data := base.Json{ + "dir_init_lock": false, + "dir_path": "", + "file_name": dirName, + "pdir_fid": parentDir.GetID(), + } + _, err := d.request("/file", http.MethodPost, func(req *resty.Request) { + req.SetBody(data) + }, nil) + if err == nil { + time.Sleep(time.Second) + } + return err +} + +func (d *QuarkOrUC) Move(ctx context.Context, srcObj, dstDir model.Obj) error { + data := base.Json{ + "action_type": 1, + "exclude_fids": []string{}, + "filelist": []string{srcObj.GetID()}, + "to_pdir_fid": dstDir.GetID(), + } + _, err := d.request("/file/move", http.MethodPost, func(req *resty.Request) { + req.SetBody(data) + }, nil) + return err +} + +func (d *QuarkOrUC) Rename(ctx context.Context, srcObj model.Obj, newName string) error { + data := base.Json{ + "fid": srcObj.GetID(), + "file_name": newName, + } + _, err := d.request("/file/rename", http.MethodPost, func(req *resty.Request) { + req.SetBody(data) + }, nil) + return err +} + +func (d *QuarkOrUC) Copy(ctx context.Context, srcObj, dstDir model.Obj) error { + return errs.NotSupport +} + +func (d *QuarkOrUC) Remove(ctx context.Context, obj model.Obj) error { + data := base.Json{ + "action_type": 1, + "exclude_fids": []string{}, + "filelist": []string{obj.GetID()}, + } + _, err := d.request("/file/delete", http.MethodPost, func(req *resty.Request) { + req.SetBody(data) + }, nil) + return err +} + +func (d *QuarkOrUC) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error { + tempFile, err := stream.CacheFullInTempFile() + if err != nil { + return err + } + defer func() { + _ = tempFile.Close() + }() + m := md5.New() + _, err = utils.CopyWithBuffer(m, tempFile) + if err != nil { + return err + } + _, err = tempFile.Seek(0, io.SeekStart) + if err != nil { + return err + } + md5Str := hex.EncodeToString(m.Sum(nil)) + s := sha1.New() + _, err = utils.CopyWithBuffer(s, tempFile) + if err != nil { + return err + } + _, err = tempFile.Seek(0, io.SeekStart) + if err != nil { + return err + } + sha1Str := hex.EncodeToString(s.Sum(nil)) + // pre + pre, err := d.upPre(stream, dstDir.GetID()) + if err != nil { + return err + } + log.Debugln("hash: ", md5Str, sha1Str) + // hash + finish, err := d.upHash(md5Str, sha1Str, pre.Data.TaskId) + if err != nil { + return err + } + if finish { + return nil + } + // part up + partSize := pre.Metadata.PartSize + var bytes []byte + md5s := make([]string, 0) + defaultBytes := make([]byte, partSize) + total := stream.GetSize() + left := total + partNumber := 1 + for left > 0 { + if utils.IsCanceled(ctx) { + return ctx.Err() + } + if left > int64(partSize) { + bytes = defaultBytes + } else { + bytes = make([]byte, left) + } + _, err := io.ReadFull(tempFile, bytes) + if err != nil { + return err + } + left -= int64(len(bytes)) + log.Debugf("left: %d", left) + m, err := d.upPart(ctx, pre, stream.GetMimetype(), partNumber, bytes) + //m, err := driver.UpPart(pre, file.GetMIMEType(), partNumber, bytes, account, md5Str, sha1Str) + if err != nil { + return err + } + if m == "finish" { + return nil + } + md5s = append(md5s, m) + partNumber++ 
+ up(100 * float64(total-left) / float64(total)) + } + err = d.upCommit(pre, md5s) + if err != nil { + return err + } + return d.upFinish(pre) +} + +var _ driver.Driver = (*QuarkOrUC)(nil) diff --git a/drivers/quark_uc/meta.go b/drivers/quark_uc/meta.go new file mode 100644 index 0000000000000000000000000000000000000000..f3acfe88562aba3bb71cb59e8bfeab09985b345c --- /dev/null +++ b/drivers/quark_uc/meta.go @@ -0,0 +1,55 @@ +package quark + +import ( + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/op" +) + +type Addition struct { + Cookie string `json:"cookie" required:"true"` + driver.RootID + OrderBy string `json:"order_by" type:"select" options:"none,file_type,file_name,updated_at" default:"none"` + OrderDirection string `json:"order_direction" type:"select" options:"asc,desc" default:"asc"` +} + +type Conf struct { + ua string + referer string + api string + pr string +} + +func init() { + op.RegisterDriver(func() driver.Driver { + return &QuarkOrUC{ + config: driver.Config{ + Name: "Quark", + OnlyLocal: true, + DefaultRoot: "0", + NoOverwriteUpload: true, + }, + conf: Conf{ + ua: "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) quark-cloud-drive/2.5.20 Chrome/100.0.4896.160 Electron/18.3.5.4-b478491100 Safari/537.36 Channel/pckk_other_ch", + referer: "https://pan.quark.cn", + api: "https://drive.quark.cn/1/clouddrive", + pr: "ucpro", + }, + } + }) + op.RegisterDriver(func() driver.Driver { + return &QuarkOrUC{ + config: driver.Config{ + Name: "UC", + OnlyLocal: true, + DefaultRoot: "0", + NoOverwriteUpload: true, + }, + conf: Conf{ + ua: "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) uc-cloud-drive/2.5.20 Chrome/100.0.4896.160 Electron/18.3.5.4-b478491100 Safari/537.36 Channel/pckk_other_ch", + referer: "https://drive.uc.cn", + api: "https://pc-api.uc.cn/1/clouddrive", + pr: "UCBrowser", + }, + } + }) +} diff --git a/drivers/quark_uc/types.go b/drivers/quark_uc/types.go new file mode 100644 index 0000000000000000000000000000000000000000..afbdb3eff89360cf28e5d10ae036102554ba72b4 --- /dev/null +++ b/drivers/quark_uc/types.go @@ -0,0 +1,150 @@ +package quark + +import ( + "time" + + "github.com/alist-org/alist/v3/internal/model" +) + +type Resp struct { + Status int `json:"status"` + Code int `json:"code"` + Message string `json:"message"` + //ReqId string `json:"req_id"` + //Timestamp int `json:"timestamp"` +} + +type File struct { + Fid string `json:"fid"` + FileName string `json:"file_name"` + //PdirFid string `json:"pdir_fid"` + //Category int `json:"category"` + //FileType int `json:"file_type"` + Size int64 `json:"size"` + //FormatType string `json:"format_type"` + //Status int `json:"status"` + //Tags string `json:"tags,omitempty"` + //LCreatedAt int64 `json:"l_created_at"` + LUpdatedAt int64 `json:"l_updated_at"` + //NameSpace int `json:"name_space"` + //IncludeItems int `json:"include_items,omitempty"` + //RiskType int `json:"risk_type"` + //BackupSign int `json:"backup_sign"` + //Duration int `json:"duration"` + //FileSource string `json:"file_source"` + File bool `json:"file"` + //CreatedAt int64 `json:"created_at"` + UpdatedAt int64 `json:"updated_at"` + //PrivateExtra struct {} `json:"_private_extra"` + //ObjCategory string `json:"obj_category,omitempty"` + //Thumbnail string `json:"thumbnail,omitempty"` +} + +func fileToObj(f File) *model.Object { + return &model.Object{ + ID: f.Fid, + Name: f.FileName, + Size: f.Size, + Modified: time.UnixMilli(f.UpdatedAt), + 
IsFolder: !f.File, + } +} + +type SortResp struct { + Resp + Data struct { + List []File `json:"list"` + } `json:"data"` + Metadata struct { + Size int `json:"_size"` + Page int `json:"_page"` + Count int `json:"_count"` + Total int `json:"_total"` + Way string `json:"way"` + } `json:"metadata"` +} + +type DownResp struct { + Resp + Data []struct { + //Fid string `json:"fid"` + //FileName string `json:"file_name"` + //PdirFid string `json:"pdir_fid"` + //Category int `json:"category"` + //FileType int `json:"file_type"` + //Size int `json:"size"` + //FormatType string `json:"format_type"` + //Status int `json:"status"` + //Tags string `json:"tags"` + //LCreatedAt int64 `json:"l_created_at"` + //LUpdatedAt int64 `json:"l_updated_at"` + //NameSpace int `json:"name_space"` + //Thumbnail string `json:"thumbnail"` + DownloadUrl string `json:"download_url"` + //Md5 string `json:"md5"` + //RiskType int `json:"risk_type"` + //RangeSize int `json:"range_size"` + //BackupSign int `json:"backup_sign"` + //ObjCategory string `json:"obj_category"` + //Duration int `json:"duration"` + //FileSource string `json:"file_source"` + //File bool `json:"file"` + //CreatedAt int64 `json:"created_at"` + //UpdatedAt int64 `json:"updated_at"` + //PrivateExtra struct { + //} `json:"_private_extra"` + } `json:"data"` + //Metadata struct { + // Acc2 string `json:"acc2"` + // Acc1 string `json:"acc1"` + //} `json:"metadata"` +} + +type UpPreResp struct { + Resp + Data struct { + TaskId string `json:"task_id"` + Finish bool `json:"finish"` + UploadId string `json:"upload_id"` + ObjKey string `json:"obj_key"` + UploadUrl string `json:"upload_url"` + Fid string `json:"fid"` + Bucket string `json:"bucket"` + Callback struct { + CallbackUrl string `json:"callbackUrl"` + CallbackBody string `json:"callbackBody"` + } `json:"callback"` + FormatType string `json:"format_type"` + Size int `json:"size"` + AuthInfo string `json:"auth_info"` + } `json:"data"` + Metadata struct { + PartThread int `json:"part_thread"` + Acc2 string `json:"acc2"` + Acc1 string `json:"acc1"` + PartSize int `json:"part_size"` // 分片大小 + } `json:"metadata"` +} + +type HashResp struct { + Resp + Data struct { + Finish bool `json:"finish"` + Fid string `json:"fid"` + Thumbnail string `json:"thumbnail"` + FormatType string `json:"format_type"` + } `json:"data"` + Metadata struct { + } `json:"metadata"` +} + +type UpAuthResp struct { + Resp + Data struct { + AuthKey string `json:"auth_key"` + Speed int `json:"speed"` + Headers []interface{} `json:"headers"` + } `json:"data"` + Metadata struct { + } `json:"metadata"` +} diff --git a/drivers/quark_uc/util.go b/drivers/quark_uc/util.go new file mode 100644 index 0000000000000000000000000000000000000000..df27af6714f8e6d912d87193263f88cf84c5fc66 --- /dev/null +++ b/drivers/quark_uc/util.go @@ -0,0 +1,252 @@ +package quark + +import ( + "context" + "crypto/md5" + "encoding/base64" + "errors" + "fmt" + "net/http" + "strconv" + "strings" + "time" + + "github.com/alist-org/alist/v3/drivers/base" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/internal/op" + "github.com/alist-org/alist/v3/pkg/cookie" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/go-resty/resty/v2" + log "github.com/sirupsen/logrus" +) + +// do others that not defined in Driver interface + +func (d *QuarkOrUC) request(pathname string, method string, callback base.ReqCallback, resp interface{}) ([]byte, error) { + u := d.conf.api + pathname + req := base.RestyClient.R() + req.SetHeaders(map[string]string{ + 
"Cookie": d.Cookie, + "Accept": "application/json, text/plain, */*", + "Referer": d.conf.referer, + }) + req.SetQueryParam("pr", d.conf.pr) + req.SetQueryParam("fr", "pc") + if callback != nil { + callback(req) + } + if resp != nil { + req.SetResult(resp) + } + var e Resp + req.SetError(&e) + res, err := req.Execute(method, u) + if err != nil { + return nil, err + } + __puus := cookie.GetCookie(res.Cookies(), "__puus") + if __puus != nil { + d.Cookie = cookie.SetStr(d.Cookie, "__puus", __puus.Value) + op.MustSaveDriverStorage(d) + } + if e.Status >= 400 || e.Code != 0 { + return nil, errors.New(e.Message) + } + return res.Body(), nil +} + +func (d *QuarkOrUC) GetFiles(parent string) ([]File, error) { + files := make([]File, 0) + page := 1 + size := 100 + query := map[string]string{ + "pdir_fid": parent, + "_size": strconv.Itoa(size), + "_fetch_total": "1", + } + if d.OrderBy != "none" { + query["_sort"] = "file_type:asc," + d.OrderBy + ":" + d.OrderDirection + } + for { + query["_page"] = strconv.Itoa(page) + var resp SortResp + _, err := d.request("/file/sort", http.MethodGet, func(req *resty.Request) { + req.SetQueryParams(query) + }, &resp) + if err != nil { + return nil, err + } + files = append(files, resp.Data.List...) + if page*size >= resp.Metadata.Total { + break + } + page++ + } + return files, nil +} + +func (d *QuarkOrUC) upPre(file model.FileStreamer, parentId string) (UpPreResp, error) { + now := time.Now() + data := base.Json{ + "ccp_hash_update": true, + "dir_name": "", + "file_name": file.GetName(), + "format_type": file.GetMimetype(), + "l_created_at": now.UnixMilli(), + "l_updated_at": now.UnixMilli(), + "pdir_fid": parentId, + "size": file.GetSize(), + //"same_path_reuse": true, + } + var resp UpPreResp + _, err := d.request("/file/upload/pre", http.MethodPost, func(req *resty.Request) { + req.SetBody(data) + }, &resp) + return resp, err +} + +func (d *QuarkOrUC) upHash(md5, sha1, taskId string) (bool, error) { + data := base.Json{ + "md5": md5, + "sha1": sha1, + "task_id": taskId, + } + log.Debugf("hash: %+v", data) + var resp HashResp + _, err := d.request("/file/update/hash", http.MethodPost, func(req *resty.Request) { + req.SetBody(data) + }, &resp) + return resp.Data.Finish, err +} + +func (d *QuarkOrUC) upPart(ctx context.Context, pre UpPreResp, mineType string, partNumber int, bytes []byte) (string, error) { + //func (driver QuarkOrUC) UpPart(pre UpPreResp, mineType string, partNumber int, bytes []byte, account *model.Account, md5Str, sha1Str string) (string, error) { + timeStr := time.Now().UTC().Format(http.TimeFormat) + data := base.Json{ + "auth_info": pre.Data.AuthInfo, + "auth_meta": fmt.Sprintf(`PUT + +%s +%s +x-oss-date:%s +x-oss-user-agent:aliyun-sdk-js/6.6.1 Chrome 98.0.4758.80 on Windows 10 64-bit +/%s/%s?partNumber=%d&uploadId=%s`, + mineType, timeStr, timeStr, pre.Data.Bucket, pre.Data.ObjKey, partNumber, pre.Data.UploadId), + "task_id": pre.Data.TaskId, + } + var resp UpAuthResp + _, err := d.request("/file/upload/auth", http.MethodPost, func(req *resty.Request) { + req.SetBody(data).SetContext(ctx) + }, &resp) + if err != nil { + return "", err + } + //if partNumber == 1 { + // finish, err := driver.UpHash(md5Str, sha1Str, pre.Data.TaskId, account) + // if err != nil { + // return "", err + // } + // if finish { + // return "finish", nil + // } + //} + u := fmt.Sprintf("https://%s.%s/%s", pre.Data.Bucket, pre.Data.UploadUrl[7:], pre.Data.ObjKey) + res, err := base.RestyClient.R().SetContext(ctx). 
+		SetHeaders(map[string]string{
+			"Authorization":    resp.Data.AuthKey,
+			"Content-Type":     mineType,
+			"Referer":          "https://pan.quark.cn/",
+			"x-oss-date":       timeStr,
+			"x-oss-user-agent": "aliyun-sdk-js/6.6.1 Chrome 98.0.4758.80 on Windows 10 64-bit",
+		}).
+		SetQueryParams(map[string]string{
+			"partNumber": strconv.Itoa(partNumber),
+			"uploadId":   pre.Data.UploadId,
+		}).SetBody(bytes).Put(u)
+	if err != nil {
+		return "", err
+	}
+	if res.StatusCode() != 200 {
+		return "", fmt.Errorf("up status: %d, error: %s", res.StatusCode(), res.String())
+	}
+	return res.Header().Get("ETag"), nil
+}
+
+func (d *QuarkOrUC) upCommit(pre UpPreResp, md5s []string) error {
+	timeStr := time.Now().UTC().Format(http.TimeFormat)
+	log.Debugf("md5s: %+v", md5s)
+	bodyBuilder := strings.Builder{}
+	bodyBuilder.WriteString(`<?xml version="1.0" encoding="UTF-8"?>
+<CompleteMultipartUpload>
+`)
+	for i, m := range md5s {
+		bodyBuilder.WriteString(fmt.Sprintf(`<Part>
+<PartNumber>%d</PartNumber>
+<ETag>%s</ETag>
+</Part>
+`, i+1, m))
+	}
+	bodyBuilder.WriteString("</CompleteMultipartUpload>")
+	body := bodyBuilder.String()
+	m := md5.New()
+	m.Write([]byte(body))
+	contentMd5 := base64.StdEncoding.EncodeToString(m.Sum(nil))
+	callbackBytes, err := utils.Json.Marshal(pre.Data.Callback)
+	if err != nil {
+		return err
+	}
+	callbackBase64 := base64.StdEncoding.EncodeToString(callbackBytes)
+	data := base.Json{
+		"auth_info": pre.Data.AuthInfo,
+		"auth_meta": fmt.Sprintf(`POST
+%s
+application/xml
+%s
+x-oss-callback:%s
+x-oss-date:%s
+x-oss-user-agent:aliyun-sdk-js/6.6.1 Chrome 98.0.4758.80 on Windows 10 64-bit
+/%s/%s?uploadId=%s`,
+			contentMd5, timeStr, callbackBase64, timeStr,
+			pre.Data.Bucket, pre.Data.ObjKey, pre.Data.UploadId),
+		"task_id": pre.Data.TaskId,
+	}
+	log.Debugf("xml: %s", body)
+	log.Debugf("auth data: %+v", data)
+	var resp UpAuthResp
+	_, err = d.request("/file/upload/auth", http.MethodPost, func(req *resty.Request) {
+		req.SetBody(data)
+	}, &resp)
+	if err != nil {
+		return err
+	}
+	u := fmt.Sprintf("https://%s.%s/%s", pre.Data.Bucket, pre.Data.UploadUrl[7:], pre.Data.ObjKey)
+	res, err := base.RestyClient.R().
+		SetHeaders(map[string]string{
+			"Authorization":    resp.Data.AuthKey,
+			"Content-MD5":      contentMd5,
+			"Content-Type":     "application/xml",
+			"Referer":          "https://pan.quark.cn/",
+			"x-oss-callback":   callbackBase64,
+			"x-oss-date":       timeStr,
+			"x-oss-user-agent": "aliyun-sdk-js/6.6.1 Chrome 98.0.4758.80 on Windows 10 64-bit",
+		}).
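+		// Content-MD5 follows RFC 1864: the base64 of the raw (unhexed) MD5 digest of the XML
+		// body, as computed above with base64.StdEncoding.EncodeToString(m.Sum(nil)); it must
+		// match the first %s slot reserved for it in auth_meta. x-oss-callback carries
+		// pre.Data.Callback as base64-encoded JSON.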
+		SetQueryParams(map[string]string{
+			"uploadId": pre.Data.UploadId,
+		}).SetBody(body).Post(u)
+	if err != nil {
+		return err
+	}
+	if res.StatusCode() != 200 {
+		return fmt.Errorf("up status: %d, error: %s", res.StatusCode(), res.String())
+	}
+	return nil
+}
+
+func (d *QuarkOrUC) upFinish(pre UpPreResp) error {
+	data := base.Json{
+		"obj_key": pre.Data.ObjKey,
+		"task_id": pre.Data.TaskId,
+	}
+	_, err := d.request("/file/upload/finish", http.MethodPost, func(req *resty.Request) {
+		req.SetBody(data)
+	}, nil)
+	if err != nil {
+		return err
+	}
+	time.Sleep(time.Second)
+	return nil
+}
diff --git a/drivers/quark_uc_tv/driver.go b/drivers/quark_uc_tv/driver.go
new file mode 100644
index 0000000000000000000000000000000000000000..ff7ccf20f7a8128af3e62f42d0d58dcefb2645a3
--- /dev/null
+++ b/drivers/quark_uc_tv/driver.go
@@ -0,0 +1,174 @@
+package quark_uc_tv
+
+import (
+	"context"
+	"fmt"
+	"github.com/alist-org/alist/v3/internal/op"
+	"github.com/alist-org/alist/v3/pkg/utils"
+	"github.com/go-resty/resty/v2"
+	"strconv"
+	"time"
+
+	"github.com/alist-org/alist/v3/internal/driver"
+	"github.com/alist-org/alist/v3/internal/errs"
+	"github.com/alist-org/alist/v3/internal/model"
+)
+
+type QuarkUCTV struct {
+	*QuarkUCTVCommon
+	model.Storage
+	Addition
+	config driver.Config
+	conf   Conf
+}
+
+func (d *QuarkUCTV) Config() driver.Config {
+	return d.config
+}
+
+func (d *QuarkUCTV) GetAddition() driver.Additional {
+	return &d.Addition
+}
+
+func (d *QuarkUCTV) Init(ctx context.Context) error {
+
+	if d.Addition.DeviceID == "" {
+		d.Addition.DeviceID = utils.GetMD5EncodeStr(time.Now().String())
+	}
+	op.MustSaveDriverStorage(d)
+
+	if d.QuarkUCTVCommon == nil {
+		d.QuarkUCTVCommon = &QuarkUCTVCommon{
+			AccessToken: "",
+		}
+	}
+	ctx1, cancelFunc := context.WithTimeout(ctx, 5*time.Second)
+	defer cancelFunc()
+	if d.Addition.RefreshToken == "" {
+		if d.Addition.QueryToken == "" {
+			qrData, err := d.getLoginCode(ctx1)
+			if err != nil {
+				return err
+			}
+			// show the login QR code (qr_data is assumed to be a base64-encoded image)
+			qrTemplate := `<body>
+        <img src="data:image/jpeg;base64,%s"/>
+    </body>`
+			qrPage := fmt.Sprintf(qrTemplate, qrData)
+			return fmt.Errorf("need verify: \n%s", qrPage)
+		} else {
+			// use the query token to obtain a code, then a refresh token
+			code, err := d.getCode(ctx1)
+			if err != nil {
+				return err
+			}
+			// exchange the code for a refresh token
+			err = d.getRefreshTokenByTV(ctx1, code, false)
+			if err != nil {
+				return err
+			}
+		}
+	}
+	// exchange the refresh token for an access token
+	if d.QuarkUCTVCommon.AccessToken == "" {
+		err := d.getRefreshTokenByTV(ctx1, d.Addition.RefreshToken, true)
+		if err != nil {
+			return err
+		}
+	}
+
+	// verify that the access token is valid
+	_, err := d.isLogin(ctx1)
+	if err != nil {
+		return err
+	}
+	return nil
+}
+
+func (d *QuarkUCTV) Drop(ctx context.Context) error {
+	return nil
+}
+
+func (d *QuarkUCTV) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) {
+	files := make([]model.Obj, 0)
+	pageIndex := int64(0)
+	pageSize := int64(100)
+	for {
+		var filesData FilesData
+		_, err := d.request(ctx, "/file", "GET", func(req *resty.Request) {
+			req.SetQueryParams(map[string]string{
+				"method":     "list",
+				"parent_fid": dir.GetID(),
+				"order_by":   "3",
+				"desc":       "1",
+				"category":   "",
+				"source":     "",
+				"ex_source":  "",
+				"list_all":   "0",
+				"page_size":  strconv.FormatInt(pageSize, 10),
+				"page_index": strconv.FormatInt(pageIndex, 10),
+			})
+		}, &filesData)
+		if err != nil {
+			return nil, err
+		}
+		for i := range filesData.Data.Files {
+			files = append(files, &filesData.Data.Files[i])
+		}
+		if pageIndex*pageSize >= filesData.Data.TotalCount {
+			break
+		} else {
+			pageIndex++
+		}
+	}
+	return files, nil
+}
+
+func (d *QuarkUCTV) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) {
+	files := &model.Link{}
+	var fileLink FileLink
+	_, err := d.request(ctx, "/file", "GET", func(req *resty.Request) {
+		req.SetQueryParams(map[string]string{
+			"method":     "download",
+			"group_by":   "source",
+			"fid":        file.GetID(),
+			"resolution": "low,normal,high,super,2k,4k",
+			"support":    "dolby_vision",
+		})
+	}, &fileLink)
+	if err != nil {
+		return nil, err
+	}
+	files.URL = fileLink.Data.DownloadURL
+	return files, nil
+}
+
+func (d *QuarkUCTV) MakeDir(ctx context.Context, parentDir model.Obj, dirName string) (model.Obj, error) {
+	return nil, errs.NotImplement
+}
+
+func (d *QuarkUCTV) Move(ctx context.Context, srcObj, dstDir model.Obj) (model.Obj, error) {
+	return nil, errs.NotImplement
+}
+
+func (d *QuarkUCTV) Rename(ctx context.Context, srcObj model.Obj, newName string) (model.Obj, error) {
+	return nil, errs.NotImplement
+}
+
+func (d *QuarkUCTV) Copy(ctx context.Context, srcObj, dstDir model.Obj) (model.Obj, error) {
+	return nil, errs.NotImplement
+}
+
+func (d *QuarkUCTV) Remove(ctx context.Context, obj model.Obj) error {
+	return errs.NotImplement
+}
+
+func (d *QuarkUCTV) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) (model.Obj, error) {
+	return nil, errs.NotImplement
+}
+
+type QuarkUCTVCommon struct {
+	AccessToken string
+}
+
+var _ driver.Driver = (*QuarkUCTV)(nil)
diff --git a/drivers/quark_uc_tv/meta.go b/drivers/quark_uc_tv/meta.go
new file mode 100644
index 0000000000000000000000000000000000000000..cf7e478566e3ef1287c1ace3efbda0d210745065
--- /dev/null
+++ b/drivers/quark_uc_tv/meta.go
@@ -0,0 +1,67 @@
+package quark_uc_tv
+
+import (
+	"github.com/alist-org/alist/v3/internal/driver"
+	"github.com/alist-org/alist/v3/internal/op"
+)
+
+type Addition struct {
+	// Usually one of two
+	driver.RootID
+	// define other
+	RefreshToken string `json:"refresh_token" required:"false" default:""`
+	// Required; it feeds the request signature, so it affects login
+	DeviceID string `json:"device_id" required:"false" default:""`
+	// Data used during login; no need to fill it in manually
+	QueryToken string `json:"query_token" required:"false" default:"" help:"don't edit"`
+}
+
+type Conf struct {
+	api      string
+	clientID string
+	signKey  string
+	appVer   string
+	channel  string
+	codeApi  string
+}
+
+func init() {
+	op.RegisterDriver(func() driver.Driver {
+		return &QuarkUCTV{
+			config: driver.Config{
+				Name:              "QuarkTV",
+				OnlyLocal:         false,
+				DefaultRoot:       "0",
+				NoOverwriteUpload: true,
+				NoUpload:          true,
+			},
+			conf: Conf{
+				api:      "https://open-api-drive.quark.cn",
+				clientID: "d3194e61504e493eb6222857bccfed94",
+				signKey:  "kw2dvtd7p4t3pjl2d9ed9yc8yej8kw2d",
+				appVer:   "1.5.6",
+				channel:  "CP",
+				codeApi:  "http://api.extscreen.com/quarkdrive",
+			},
+		}
+	})
+	op.RegisterDriver(func() driver.Driver {
+		return &QuarkUCTV{
+			config: driver.Config{
+				Name:              "UCTV",
+				OnlyLocal:         false,
+				DefaultRoot:       "0",
+				NoOverwriteUpload: true,
+				NoUpload:          true,
+			},
+			conf: Conf{
+				api:      "https://open-api-drive.uc.cn",
+				clientID: "5acf882d27b74502b7040b0c65519aa7",
+				signKey:  "l3srvtd7p42l0d0x1u8d7yc8ye9kki4d",
+				appVer:   "1.6.5",
+				channel:  "UCTVOFFICIALWEB",
+				codeApi:  "http://api.extscreen.com/ucdrive",
+			},
+		}
+	})
+}
diff --git a/drivers/quark_uc_tv/types.go b/drivers/quark_uc_tv/types.go
new file mode 100644
index 0000000000000000000000000000000000000000..fb35b8b2d6dd67382f274e118b8f7fb947319ece
--- /dev/null
+++ b/drivers/quark_uc_tv/types.go
@@ -0,0 +1,102 @@
+package quark_uc_tv
+
+import (
"github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/pkg/utils" + "time" +) + +type Resp struct { + CommonRsp + Errno int `json:"errno"` + ErrorInfo string `json:"error_info"` +} + +type CommonRsp struct { + Status int `json:"status"` + ReqID string `json:"req_id"` +} + +type RefreshTokenAuthResp struct { + Code int `json:"code"` + Message string `json:"message"` + Data struct { + Status int `json:"status"` + Errno int `json:"errno"` + ErrorInfo string `json:"error_info"` + ReqID string `json:"req_id"` + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + ExpiresIn int `json:"expires_in"` + Scope string `json:"scope"` + } `json:"data"` +} +type Files struct { + Fid string `json:"fid"` + ParentFid string `json:"parent_fid"` + Category int `json:"category"` + Filename string `json:"filename"` + Size int64 `json:"size"` + FileType string `json:"file_type"` + SubItems int `json:"sub_items,omitempty"` + Isdir int `json:"isdir"` + Duration int `json:"duration"` + CreatedAt int64 `json:"created_at"` + UpdatedAt int64 `json:"updated_at"` + IsBackup int `json:"is_backup"` + ThumbnailURL string `json:"thumbnail_url,omitempty"` +} + +func (f *Files) GetSize() int64 { + return f.Size +} + +func (f *Files) GetName() string { + return f.Filename +} + +func (f *Files) ModTime() time.Time { + //return time.Unix(f.UpdatedAt, 0) + return time.Unix(0, f.UpdatedAt*int64(time.Millisecond)) +} + +func (f *Files) CreateTime() time.Time { + //return time.Unix(f.CreatedAt, 0) + return time.Unix(0, f.CreatedAt*int64(time.Millisecond)) +} + +func (f *Files) IsDir() bool { + return f.Isdir == 1 +} + +func (f *Files) GetHash() utils.HashInfo { + return utils.HashInfo{} +} + +func (f *Files) GetID() string { + return f.Fid +} + +func (f *Files) GetPath() string { + return "" +} + +var _ model.Obj = (*Files)(nil) + +type FilesData struct { + CommonRsp + Data struct { + TotalCount int64 `json:"total_count"` + Files []Files `json:"files"` + } `json:"data"` +} + +type FileLink struct { + CommonRsp + Data struct { + Fid string `json:"fid"` + FileName string `json:"file_name"` + Size int64 `json:"size"` + DownloadURL string `json:"download_url"` + } `json:"data"` +} diff --git a/drivers/quark_uc_tv/util.go b/drivers/quark_uc_tv/util.go new file mode 100644 index 0000000000000000000000000000000000000000..fefbb0361fb113a91e1501a1f0cf85153e4140aa --- /dev/null +++ b/drivers/quark_uc_tv/util.go @@ -0,0 +1,211 @@ +package quark_uc_tv + +import ( + "context" + "crypto/md5" + "crypto/sha256" + "encoding/hex" + "errors" + "github.com/alist-org/alist/v3/drivers/base" + "github.com/alist-org/alist/v3/internal/op" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/go-resty/resty/v2" + "net/http" + "strconv" + "time" +) + +const ( + UserAgent = "Mozilla/5.0 (Linux; U; Android 13; zh-cn; M2004J7AC Build/UKQ1.231108.001) AppleWebKit/533.1 (KHTML, like Gecko) Mobile Safari/533.1" + DeviceBrand = "Xiaomi" + Platform = "tv" + DeviceName = "M2004J7AC" + DeviceModel = "M2004J7AC" + BuildDevice = "M2004J7AC" + BuildProduct = "M2004J7AC" + DeviceGpu = "Adreno (TM) 550" + ActivityRect = "{}" +) + +func (d *QuarkUCTV) request(ctx context.Context, pathname string, method string, callback base.ReqCallback, resp interface{}) ([]byte, error) { + u := d.conf.api + pathname + tm, token, reqID := d.generateReqSign(method, pathname, d.conf.signKey) + req := base.RestyClient.R() + req.SetContext(ctx) + req.SetHeaders(map[string]string{ + "Accept": "application/json, text/plain, */*", + 
"User-Agent": UserAgent, + "x-pan-tm": tm, + "x-pan-token": token, + "x-pan-client-id": d.conf.clientID, + }) + req.SetQueryParams(map[string]string{ + "req_id": reqID, + "access_token": d.QuarkUCTVCommon.AccessToken, + "app_ver": d.conf.appVer, + "device_id": d.Addition.DeviceID, + "device_brand": DeviceBrand, + "platform": Platform, + "device_name": DeviceName, + "device_model": DeviceModel, + "build_device": BuildDevice, + "build_product": BuildProduct, + "device_gpu": DeviceGpu, + "activity_rect": ActivityRect, + "channel": d.conf.channel, + }) + if callback != nil { + callback(req) + } + if resp != nil { + req.SetResult(resp) + } + var e Resp + req.SetError(&e) + res, err := req.Execute(method, u) + if err != nil { + return nil, err + } + // 判断 是否需要 刷新 access_token + if e.Status == -1 && e.Errno == 10001 { + // token 过期 + err = d.getRefreshTokenByTV(ctx, d.Addition.RefreshToken, true) + if err != nil { + return nil, err + } + ctx1, cancelFunc := context.WithTimeout(ctx, 10*time.Second) + defer cancelFunc() + return d.request(ctx1, pathname, method, callback, resp) + } + + if e.Status >= 400 || e.Errno != 0 { + return nil, errors.New(e.ErrorInfo) + } + return res.Body(), nil +} + +func (d *QuarkUCTV) getLoginCode(ctx context.Context) (string, error) { + // 获取登录二维码 + pathname := "/oauth/authorize" + var resp struct { + CommonRsp + QrData string `json:"qr_data"` + QueryToken string `json:"query_token"` + } + _, err := d.request(ctx, pathname, "GET", func(req *resty.Request) { + req.SetQueryParams(map[string]string{ + "auth_type": "code", + "client_id": d.conf.clientID, + "scope": "netdisk", + "qrcode": "1", + "qr_width": "460", + "qr_height": "460", + }) + }, &resp) + if err != nil { + return "", err + } + // 保存query_token 用于后续登录 + if resp.QueryToken != "" { + d.Addition.QueryToken = resp.QueryToken + op.MustSaveDriverStorage(d) + } + return resp.QrData, nil +} + +func (d *QuarkUCTV) getCode(ctx context.Context) (string, error) { + // 通过query token获取code + pathname := "/oauth/code" + var resp struct { + CommonRsp + Code string `json:"code"` + } + _, err := d.request(ctx, pathname, "GET", func(req *resty.Request) { + req.SetQueryParams(map[string]string{ + "client_id": d.conf.clientID, + "scope": "netdisk", + "query_token": d.Addition.QueryToken, + }) + }, &resp) + if err != nil { + return "", err + } + return resp.Code, nil +} + +func (d *QuarkUCTV) getRefreshTokenByTV(ctx context.Context, code string, isRefresh bool) error { + pathname := "/token" + _, _, reqID := d.generateReqSign("POST", pathname, d.conf.signKey) + u := d.conf.codeApi + pathname + var resp RefreshTokenAuthResp + body := map[string]string{ + "req_id": reqID, + "app_ver": d.conf.appVer, + "device_id": d.Addition.DeviceID, + "device_brand": DeviceBrand, + "platform": Platform, + "device_name": DeviceName, + "device_model": DeviceModel, + "build_device": BuildDevice, + "build_product": BuildProduct, + "device_gpu": DeviceGpu, + "activity_rect": ActivityRect, + "channel": d.conf.channel, + } + if isRefresh { + body["refresh_token"] = code + } else { + body["code"] = code + } + + _, err := base.RestyClient.R(). + SetHeader("Content-Type", "application/json"). + SetBody(body). + SetResult(&resp). + SetContext(ctx). 
+		Post(u)
+	if err != nil {
+		return err
+	}
+	if resp.Code != 200 {
+		return errors.New(resp.Message)
+	}
+	if resp.Data.RefreshToken != "" {
+		d.Addition.RefreshToken = resp.Data.RefreshToken
+		op.MustSaveDriverStorage(d)
+		d.QuarkUCTVCommon.AccessToken = resp.Data.AccessToken
+	} else {
+		return errors.New("refresh token is empty")
+	}
+	return nil
+}
+
+func (d *QuarkUCTV) isLogin(ctx context.Context) (bool, error) {
+	_, err := d.request(ctx, "/user", http.MethodGet, func(req *resty.Request) {
+		req.SetQueryParams(map[string]string{
+			"method": "user_info",
+		})
+	}, nil)
+	return err == nil, err
+}
+
+func (d *QuarkUCTV) generateReqSign(method string, pathname string, key string) (string, string, string) {
+	// timestamp: 13-digit (millisecond) Unix timestamp
+	timestamp := strconv.FormatInt(time.Now().UnixNano()/int64(time.Millisecond), 10)
+	deviceID := d.Addition.DeviceID
+	if deviceID == "" {
+		deviceID = utils.GetMD5EncodeStr(timestamp)
+		d.Addition.DeviceID = deviceID
+		op.MustSaveDriverStorage(d)
+	}
+	// generate req_id
+	reqID := md5.Sum([]byte(deviceID + timestamp))
+	reqIDHex := hex.EncodeToString(reqID[:])
+
+	// generate x_pan_token
+	tokenData := method + "&" + pathname + "&" + timestamp + "&" + key
+	xPanToken := sha256.Sum256([]byte(tokenData))
+	xPanTokenHex := hex.EncodeToString(xPanToken[:])
+
+	return timestamp, xPanTokenHex, reqIDHex
+}
diff --git a/drivers/quqi/driver.go b/drivers/quqi/driver.go
new file mode 100644
index 0000000000000000000000000000000000000000..51e54981a186cf6d386c69eda7ddc957cfdb78fa
--- /dev/null
+++ b/drivers/quqi/driver.go
@@ -0,0 +1,437 @@
+package quqi
+
+import (
+	"bytes"
+	"context"
+	"io"
+	"strconv"
+	"strings"
+	"time"
+
+	"github.com/alist-org/alist/v3/internal/driver"
+	"github.com/alist-org/alist/v3/internal/errs"
+	"github.com/alist-org/alist/v3/internal/model"
+	"github.com/alist-org/alist/v3/pkg/utils"
+	"github.com/alist-org/alist/v3/pkg/utils/random"
+	"github.com/aws/aws-sdk-go/aws"
+	"github.com/aws/aws-sdk-go/aws/credentials"
+	"github.com/aws/aws-sdk-go/aws/session"
+	"github.com/aws/aws-sdk-go/service/s3"
+	"github.com/aws/aws-sdk-go/service/s3/s3manager"
+	"github.com/go-resty/resty/v2"
+	log "github.com/sirupsen/logrus"
+)
+
+type Quqi struct {
+	model.Storage
+	Addition
+	Cookie   string // session cookie
+	GroupID  string // group ID of the private cloud
+	ClientID string // randomly generated client ID; testing shows some API calls fail without it
+}
+
+func (d *Quqi) Config() driver.Config {
+	return config
+}
+
+func (d *Quqi) GetAddition() driver.Additional {
+	return &d.Addition
+}
+
+func (d *Quqi) Init(ctx context.Context) error {
+	// log in
+	if err := d.login(); err != nil {
+		return err
+	}
+
+	// generate a random client id (same logic as the web client)
+	d.ClientID = "quqipc_" + random.String(10)
+
+	// fetch the private cloud ID (only the private cloud for now)
+	groupResp := &GroupRes{}
+	if _, err := d.request("group.quqi.com", "/v1/group/list", resty.MethodGet, nil, groupResp); err != nil {
+		return err
+	}
+	for _, groupInfo := range groupResp.Data {
+		if groupInfo == nil {
+			continue
+		}
+		if groupInfo.Type == 2 {
+			d.GroupID = strconv.Itoa(groupInfo.ID)
+			break
+		}
+	}
+	if d.GroupID == "" {
+		return errs.StorageNotFound
+	}
+
+	return nil
+}
+
+func (d *Quqi) Drop(ctx context.Context) error {
+	return nil
+}
+
+func (d *Quqi) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) {
+	var (
+		listResp = &ListRes{}
+		files    []model.Obj
+	)
+
+	if _, err := d.request("", "/api/dir/ls", resty.MethodPost, func(req *resty.Request) {
+		req.SetFormData(map[string]string{
+			"quqi_id":   d.GroupID,
+			"tree_id":   "1",
+			"node_id":   dir.GetID(),
+			"client_id": d.ClientID,
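+			// The same four form fields recur across the Quqi API: quqi_id selects the group
+			// (here the private cloud found in Init), tree_id "1" appears to be the default
+			// document tree, node_id addresses the file or folder, and client_id must be sent
+			// because some endpoints reject requests without it.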
+		})
+	}, listResp); err != nil {
+		return nil, err
+	}
+
+	if listResp.Data == nil {
+		return nil, nil
+	}
+
+	// dirs
+	for _, dirInfo := range listResp.Data.Dir {
+		if dirInfo == nil {
+			continue
+		}
+		files = append(files, &model.Object{
+			ID:       strconv.FormatInt(dirInfo.NodeID, 10),
+			Name:     dirInfo.Name,
+			Modified: time.Unix(dirInfo.UpdateTime, 0),
+			Ctime:    time.Unix(dirInfo.AddTime, 0),
+			IsFolder: true,
+		})
+	}
+
+	// files
+	for _, fileInfo := range listResp.Data.File {
+		if fileInfo == nil {
+			continue
+		}
+		if fileInfo.EXT != "" {
+			fileInfo.Name = strings.Join([]string{fileInfo.Name, fileInfo.EXT}, ".")
+		}
+
+		files = append(files, &model.Object{
+			ID:       strconv.FormatInt(fileInfo.NodeID, 10),
+			Name:     fileInfo.Name,
+			Size:     fileInfo.Size,
+			Modified: time.Unix(fileInfo.UpdateTime, 0),
+			Ctime:    time.Unix(fileInfo.AddTime, 0),
+		})
+	}
+
+	return files, nil
+}
+
+func (d *Quqi) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) {
+	if d.CDN {
+		link, err := d.linkFromCDN(file.GetID())
+		if err != nil {
+			log.Warn(err)
+		} else {
+			return link, nil
+		}
+	}
+
+	link, err := d.linkFromPreview(file.GetID())
+	if err != nil {
+		log.Warn(err)
+	} else {
+		return link, nil
+	}
+
+	link, err = d.linkFromDownload(file.GetID())
+	if err != nil {
+		return nil, err
+	}
+	return link, nil
+}
+
+func (d *Quqi) MakeDir(ctx context.Context, parentDir model.Obj, dirName string) (model.Obj, error) {
+	var (
+		makeDirRes = &MakeDirRes{}
+		timeNow    = time.Now()
+	)
+
+	if _, err := d.request("", "/api/dir/mkDir", resty.MethodPost, func(req *resty.Request) {
+		req.SetFormData(map[string]string{
+			"quqi_id":   d.GroupID,
+			"tree_id":   "1",
+			"parent_id": parentDir.GetID(),
+			"name":      dirName,
+			"client_id": d.ClientID,
+		})
+	}, makeDirRes); err != nil {
+		return nil, err
+	}
+
+	return &model.Object{
+		ID:       strconv.FormatInt(makeDirRes.Data.NodeID, 10),
+		Name:     dirName,
+		Modified: timeNow,
+		Ctime:    timeNow,
+		IsFolder: true,
+	}, nil
+}
+
+func (d *Quqi) Move(ctx context.Context, srcObj, dstDir model.Obj) (model.Obj, error) {
+	var moveRes = &MoveRes{}
+
+	if _, err := d.request("", "/api/dir/mvDir", resty.MethodPost, func(req *resty.Request) {
+		req.SetFormData(map[string]string{
+			"quqi_id":        d.GroupID,
+			"tree_id":        "1",
+			"node_id":        dstDir.GetID(),
+			"source_quqi_id": d.GroupID,
+			"source_tree_id": "1",
+			"source_node_id": srcObj.GetID(),
+			"client_id":      d.ClientID,
+		})
+	}, moveRes); err != nil {
+		return nil, err
+	}
+
+	return &model.Object{
+		ID:       strconv.FormatInt(moveRes.Data.NodeID, 10),
+		Name:     moveRes.Data.NodeName,
+		Size:     srcObj.GetSize(),
+		Modified: time.Now(),
+		Ctime:    srcObj.CreateTime(),
+		IsFolder: srcObj.IsDir(),
+	}, nil
+}
+
+func (d *Quqi) Rename(ctx context.Context, srcObj model.Obj, newName string) (model.Obj, error) {
+	var realName = newName
+
+	if !srcObj.IsDir() {
+		srcExt, newExt := utils.Ext(srcObj.GetName()), utils.Ext(newName)
+
+		// a Quqi file name consists of a base name plus an extension; when an extension exists,
+		// rename can only change the base name, and the extension is kept on the Quqi server side
+		if srcExt != "" && srcExt == newExt {
+			parts := strings.Split(newName, ".")
+			if len(parts) > 1 {
+				realName = strings.Join(parts[:len(parts)-1], ".")
+			}
+		}
+	}
+
+	if _, err := d.request("", "/api/dir/renameDir", resty.MethodPost, func(req *resty.Request) {
+		req.SetFormData(map[string]string{
+			"quqi_id":   d.GroupID,
+			"tree_id":   "1",
+			"node_id":   srcObj.GetID(),
+			"rename":    realName,
+			"client_id": d.ClientID,
+		})
+	}, nil); err != nil {
+		return nil, err
+	}
+
+	return &model.Object{
+		ID:       srcObj.GetID(),
+		Name:     newName,
+		Size:     srcObj.GetSize(),
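+		// Example of the rule above: renaming "draft.txt" to "final.txt" submits rename="final",
+		// and Quqi re-attaches the stored ".txt"; when the extension differs (or there is none),
+		// the full new name is sent as-is.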
+		Modified: time.Now(),
+		Ctime:    srcObj.CreateTime(),
+		IsFolder: srcObj.IsDir(),
+	}, nil
+}
+
+func (d *Quqi) Copy(ctx context.Context, srcObj, dstDir model.Obj) (model.Obj, error) {
+	// the copied file's info cannot be read directly from the Quqi API response
+	if _, err := d.request("", "/api/node/copy", resty.MethodPost, func(req *resty.Request) {
+		req.SetFormData(map[string]string{
+			"quqi_id":        d.GroupID,
+			"tree_id":        "1",
+			"node_id":        dstDir.GetID(),
+			"source_quqi_id": d.GroupID,
+			"source_tree_id": "1",
+			"source_node_id": srcObj.GetID(),
+			"client_id":      d.ClientID,
+		})
+	}, nil); err != nil {
+		return nil, err
+	}
+
+	return nil, nil
+}
+
+func (d *Quqi) Remove(ctx context.Context, obj model.Obj) error {
+	// no hard delete for now; everything goes to the recycle bin by default. To delete permanently:
+	// first call the delete API to move the file into the recycle bin, then remove it through the
+	// recycle-bin API
+	if _, err := d.request("", "/api/node/del", resty.MethodPost, func(req *resty.Request) {
+		req.SetFormData(map[string]string{
+			"quqi_id":   d.GroupID,
+			"tree_id":   "1",
+			"node_id":   obj.GetID(),
+			"client_id": d.ClientID,
+		})
+	}, nil); err != nil {
+		return err
+	}
+
+	return nil
+}
+
+func (d *Quqi) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) (model.Obj, error) {
+	// base info
+	sizeStr := strconv.FormatInt(stream.GetSize(), 10)
+	f, err := stream.CacheFullInTempFile()
+	if err != nil {
+		return nil, err
+	}
+	md5, err := utils.HashFile(utils.MD5, f)
+	if err != nil {
+		return nil, err
+	}
+	sha, err := utils.HashFile(utils.SHA256, f)
+	if err != nil {
+		return nil, err
+	}
+	// init upload
+	var uploadInitResp UploadInitResp
+	_, err = d.request("", "/api/upload/v1/file/init", resty.MethodPost, func(req *resty.Request) {
+		req.SetFormData(map[string]string{
+			"quqi_id":   d.GroupID,
+			"tree_id":   "1",
+			"parent_id": dstDir.GetID(),
+			"size":      sizeStr,
+			"file_name": stream.GetName(),
+			"md5":       md5,
+			"sha":       sha,
+			"is_slice":  "true",
+			"client_id": d.ClientID,
+		})
+	}, &uploadInitResp)
+	if err != nil {
+		return nil, err
+	}
+	// check exist
+	// if the file already exists in Quqi server, there is no need to actually upload it
+	if uploadInitResp.Data.Exist {
+		// the file name returned by Quqi does not include the extension name
+		nodeName, nodeExt := uploadInitResp.Data.NodeName, rawExt(stream.GetName())
+		if nodeExt != "" {
+			nodeName = nodeName + "."
+ nodeExt + } + return &model.Object{ + ID: strconv.FormatInt(uploadInitResp.Data.NodeID, 10), + Name: nodeName, + Size: stream.GetSize(), + Modified: stream.ModTime(), + Ctime: stream.CreateTime(), + }, nil + } + // listParts + _, err = d.request("upload.quqi.com:20807", "/upload/v1/listParts", resty.MethodPost, func(req *resty.Request) { + req.SetFormData(map[string]string{ + "token": uploadInitResp.Data.Token, + "task_id": uploadInitResp.Data.TaskID, + "client_id": d.ClientID, + }) + }, nil) + if err != nil { + return nil, err + } + // get temp key + var tempKeyResp TempKeyResp + _, err = d.request("upload.quqi.com:20807", "/upload/v1/tempKey", resty.MethodGet, func(req *resty.Request) { + req.SetQueryParams(map[string]string{ + "token": uploadInitResp.Data.Token, + "task_id": uploadInitResp.Data.TaskID, + }) + }, &tempKeyResp) + if err != nil { + return nil, err + } + // upload + // u, err := url.Parse(fmt.Sprintf("https://%s.cos.ap-shanghai.myqcloud.com", uploadInitResp.Data.Bucket)) + // b := &cos.BaseURL{BucketURL: u} + // client := cos.NewClient(b, &http.Client{ + // Transport: &cos.CredentialTransport{ + // Credential: cos.NewTokenCredential(tempKeyResp.Data.Credentials.TmpSecretID, tempKeyResp.Data.Credentials.TmpSecretKey, tempKeyResp.Data.Credentials.SessionToken), + // }, + // }) + // partSize := int64(1024 * 1024 * 2) + // partCount := (stream.GetSize() + partSize - 1) / partSize + // for i := 1; i <= int(partCount); i++ { + // length := partSize + // if i == int(partCount) { + // length = stream.GetSize() - (int64(i)-1)*partSize + // } + // _, err := client.Object.UploadPart( + // ctx, uploadInitResp.Data.Key, uploadInitResp.Data.UploadID, i, io.LimitReader(f, partSize), &cos.ObjectUploadPartOptions{ + // ContentLength: length, + // }, + // ) + // if err != nil { + // return nil, err + // } + // } + + cfg := &aws.Config{ + Credentials: credentials.NewStaticCredentials(tempKeyResp.Data.Credentials.TmpSecretID, tempKeyResp.Data.Credentials.TmpSecretKey, tempKeyResp.Data.Credentials.SessionToken), + Region: aws.String("ap-shanghai"), + Endpoint: aws.String("cos.ap-shanghai.myqcloud.com"), + } + s, err := session.NewSession(cfg) + if err != nil { + return nil, err + } + uploader := s3manager.NewUploader(s) + buf := make([]byte, 1024*1024*2) + for partNumber := int64(1); ; partNumber++ { + n, err := io.ReadFull(f, buf) + if err != nil && err != io.ErrUnexpectedEOF { + if err == io.EOF { + break + } + return nil, err + } + _, err = uploader.S3.UploadPartWithContext(ctx, &s3.UploadPartInput{ + UploadId: &uploadInitResp.Data.UploadID, + Key: &uploadInitResp.Data.Key, + Bucket: &uploadInitResp.Data.Bucket, + PartNumber: aws.Int64(partNumber), + Body: bytes.NewReader(buf[:n]), + }) + if err != nil { + return nil, err + } + } + // finish upload + var uploadFinishResp UploadFinishResp + _, err = d.request("", "/api/upload/v1/file/finish", resty.MethodPost, func(req *resty.Request) { + req.SetFormData(map[string]string{ + "token": uploadInitResp.Data.Token, + "task_id": uploadInitResp.Data.TaskID, + "client_id": d.ClientID, + }) + }, &uploadFinishResp) + if err != nil { + return nil, err + } + // the file name returned by Quqi does not include the extension name + nodeName, nodeExt := uploadFinishResp.Data.NodeName, rawExt(stream.GetName()) + if nodeExt != "" { + nodeName = nodeName + "." 
+ nodeExt + } + return &model.Object{ + ID: strconv.FormatInt(uploadFinishResp.Data.NodeID, 10), + Name: nodeName, + Size: stream.GetSize(), + Modified: stream.ModTime(), + Ctime: stream.CreateTime(), + }, nil +} + +//func (d *Template) Other(ctx context.Context, args model.OtherArgs) (interface{}, error) { +// return nil, errs.NotSupport +//} + +var _ driver.Driver = (*Quqi)(nil) diff --git a/drivers/quqi/meta.go b/drivers/quqi/meta.go new file mode 100644 index 0000000000000000000000000000000000000000..aaaa0a19444d8a82e7cc17823fffc38c64dea5f3 --- /dev/null +++ b/drivers/quqi/meta.go @@ -0,0 +1,28 @@ +package quqi + +import ( + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/op" +) + +type Addition struct { + driver.RootID + Phone string `json:"phone"` + Password string `json:"password"` + Cookie string `json:"cookie" help:"Cookie can be used on multiple clients at the same time"` + CDN bool `json:"cdn" help:"If you enable this option, the download speed can be increased, but there will be some performance loss"` +} + +var config = driver.Config{ + Name: "Quqi", + OnlyLocal: true, + LocalSort: true, + //NoUpload: true, + DefaultRoot: "0", +} + +func init() { + op.RegisterDriver(func() driver.Driver { + return &Quqi{} + }) +} diff --git a/drivers/quqi/types.go b/drivers/quqi/types.go new file mode 100644 index 0000000000000000000000000000000000000000..3255736153229bdcfc1d725d5dc0197747177098 --- /dev/null +++ b/drivers/quqi/types.go @@ -0,0 +1,197 @@ +package quqi + +type BaseReqQuery struct { + ID string `json:"quqiid"` +} + +type BaseReq struct { + GroupID string `json:"quqi_id"` +} + +type BaseRes struct { + //Data interface{} `json:"data"` + Code int `json:"err"` + Message string `json:"msg"` +} + +type GroupRes struct { + BaseRes + Data []*Group `json:"data"` +} + +type ListRes struct { + BaseRes + Data *List `json:"data"` +} + +type GetDocRes struct { + BaseRes + Data struct { + OriginPath string `json:"origin_path"` + } `json:"data"` +} + +type GetDownloadResp struct { + BaseRes + Data struct { + Url string `json:"url"` + } `json:"data"` +} + +type MakeDirRes struct { + BaseRes + Data struct { + IsRoot bool `json:"is_root"` + NodeID int64 `json:"node_id"` + ParentID int64 `json:"parent_id"` + } `json:"data"` +} + +type MoveRes struct { + BaseRes + Data struct { + NodeChildNum int64 `json:"node_child_num"` + NodeID int64 `json:"node_id"` + NodeName string `json:"node_name"` + ParentID int64 `json:"parent_id"` + GroupID int64 `json:"quqi_id"` + TreeID int64 `json:"tree_id"` + } `json:"data"` +} + +type RenameRes struct { + BaseRes + Data struct { + NodeID int64 `json:"node_id"` + GroupID int64 `json:"quqi_id"` + Rename string `json:"rename"` + TreeID int64 `json:"tree_id"` + UpdateTime int64 `json:"updatetime"` + } `json:"data"` +} + +type CopyRes struct { + BaseRes +} + +type RemoveRes struct { + BaseRes +} + +type Group struct { + ID int `json:"quqi_id"` + Type int `json:"type"` + Name string `json:"name"` + IsAdministrator int `json:"is_administrator"` + Role int `json:"role"` + Avatar string `json:"avatar_url"` + IsStick int `json:"is_stick"` + Nickname string `json:"nickname"` + Status int `json:"status"` +} + +type List struct { + ListDir + Dir []*ListDir `json:"dir"` + File []*ListFile `json:"file"` +} + +type ListItem struct { + AddTime int64 `json:"add_time"` + IsDir int `json:"is_dir"` + IsExpand int `json:"is_expand"` + IsFinalize int `json:"is_finalize"` + LastEditorName string `json:"last_editor_name"` + Name string `json:"name"` 
+ NodeID int64 `json:"nid"` + ParentID int64 `json:"parent_id"` + Permission int `json:"permission"` + TreeID int64 `json:"tid"` + UpdateCNT int64 `json:"update_cnt"` + UpdateTime int64 `json:"update_time"` +} + +type ListDir struct { + ListItem + ChildDocNum int64 `json:"child_doc_num"` + DirDetail string `json:"dir_detail"` + DirType int `json:"dir_type"` +} + +type ListFile struct { + ListItem + BroadDocType string `json:"broad_doc_type"` + CanDisplay bool `json:"can_display"` + Detail string `json:"detail"` + EXT string `json:"ext"` + Filetype string `json:"filetype"` + HasMobileThumbnail bool `json:"has_mobile_thumbnail"` + HasThumbnail bool `json:"has_thumbnail"` + Size int64 `json:"size"` + Version int `json:"version"` +} + +type UploadInitResp struct { + Data struct { + Bucket string `json:"bucket"` + Exist bool `json:"exist"` + Key string `json:"key"` + TaskID string `json:"task_id"` + Token string `json:"token"` + UploadID string `json:"upload_id"` + URL string `json:"url"` + NodeID int64 `json:"node_id"` + NodeName string `json:"node_name"` + ParentID int64 `json:"parent_id"` + } `json:"data"` + Err int `json:"err"` + Msg string `json:"msg"` +} + +type TempKeyResp struct { + Err int `json:"err"` + Msg string `json:"msg"` + Data struct { + ExpiredTime int `json:"expiredTime"` + Expiration string `json:"expiration"` + Credentials struct { + SessionToken string `json:"sessionToken"` + TmpSecretID string `json:"tmpSecretId"` + TmpSecretKey string `json:"tmpSecretKey"` + } `json:"credentials"` + RequestID string `json:"requestId"` + StartTime int `json:"startTime"` + } `json:"data"` +} + +type UploadFinishResp struct { + Data struct { + NodeID int64 `json:"node_id"` + NodeName string `json:"node_name"` + ParentID int64 `json:"parent_id"` + QuqiID int64 `json:"quqi_id"` + TreeID int64 `json:"tree_id"` + } `json:"data"` + Err int `json:"err"` + Msg string `json:"msg"` +} + +type UrlExchangeResp struct { + BaseRes + Data struct { + Name string `json:"name"` + Mime string `json:"mime"` + Size int64 `json:"size"` + DownloadType int `json:"download_type"` + ChannelType int `json:"channel_type"` + ChannelID int `json:"channel_id"` + Url string `json:"url"` + ExpiredTime int64 `json:"expired_time"` + IsEncrypted bool `json:"is_encrypted"` + EncryptedSize int64 `json:"encrypted_size"` + EncryptedAlg string `json:"encrypted_alg"` + EncryptedKey string `json:"encrypted_key"` + PassportID int64 `json:"passport_id"` + RequestExpiredTime int64 `json:"request_expired_time"` + } `json:"data"` +} diff --git a/drivers/quqi/util.go b/drivers/quqi/util.go new file mode 100644 index 0000000000000000000000000000000000000000..c025f6ee8af3ec37fb53222a4bd4135fc5966782 --- /dev/null +++ b/drivers/quqi/util.go @@ -0,0 +1,316 @@ +package quqi + +import ( + "bufio" + "context" + "encoding/base64" + "errors" + "fmt" + "io" + "net/http" + "net/url" + stdpath "path" + "strings" + "time" + + "github.com/alist-org/alist/v3/drivers/base" + "github.com/alist-org/alist/v3/internal/errs" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/internal/stream" + "github.com/alist-org/alist/v3/pkg/http_range" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/go-resty/resty/v2" + "github.com/minio/sio" +) + +// do others that not defined in Driver interface +func (d *Quqi) request(host string, path string, method string, callback base.ReqCallback, resp interface{}) (*resty.Response, error) { + var ( + reqUrl = url.URL{ + Scheme: "https", + Host: "quqi.com", + Path: path, + } + req = 
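+	// This helper decodes every response twice: first into BaseRes to check the err/msg
+	// envelope, then into the caller's typed struct, because (per the comment further down)
+	// resty's SetResult does not always populate the result here.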
base.RestyClient.R()
+		result BaseRes
+	)
+
+	if host != "" {
+		reqUrl.Host = host
+	}
+	req.SetHeaders(map[string]string{
+		"Origin": "https://quqi.com",
+		"Cookie": d.Cookie,
+	})
+
+	if d.GroupID != "" {
+		req.SetQueryParam("quqiid", d.GroupID)
+	}
+
+	if callback != nil {
+		callback(req)
+	}
+
+	res, err := req.Execute(method, reqUrl.String())
+	if err != nil {
+		return nil, err
+	}
+	// resty.Request.SetResult cannot parse result correctly sometimes
+	err = utils.Json.Unmarshal(res.Body(), &result)
+	if err != nil {
+		return nil, err
+	}
+	if result.Code != 0 {
+		return nil, errors.New(result.Message)
+	}
+	if resp != nil {
+		err = utils.Json.Unmarshal(res.Body(), resp)
+		if err != nil {
+			return nil, err
+		}
+	}
+	return res, nil
+}
+
+func (d *Quqi) login() error {
+	if d.Addition.Cookie != "" {
+		d.Cookie = d.Addition.Cookie
+	}
+	if d.checkLogin() {
+		return nil
+	}
+	if d.Cookie != "" {
+		return errors.New("cookie is invalid")
+	}
+	if d.Phone == "" {
+		return errors.New("phone number is empty")
+	}
+	if d.Password == "" {
+		return errs.EmptyPassword
+	}
+
+	resp, err := d.request("", "/auth/person/v2/login/password", resty.MethodPost, func(req *resty.Request) {
+		req.SetFormData(map[string]string{
+			"phone":    d.Phone,
+			"password": base64.StdEncoding.EncodeToString([]byte(d.Password)),
+		})
+	}, nil)
+	if err != nil {
+		return err
+	}
+
+	var cookies []string
+	for _, cookie := range resp.RawResponse.Cookies() {
+		cookies = append(cookies, fmt.Sprintf("%s=%s", cookie.Name, cookie.Value))
+	}
+	d.Cookie = strings.Join(cookies, ";")
+
+	return nil
+}
+
+func (d *Quqi) checkLogin() bool {
+	if _, err := d.request("", "/auth/account/baseInfo", resty.MethodGet, nil, nil); err != nil {
+		return false
+	}
+	return true
+}
+
+// rawExt keeps the extension's original letter case
+func rawExt(name string) string {
+	ext := stdpath.Ext(name)
+	if strings.HasPrefix(ext, ".") {
+		ext = ext[1:]
+	}
+
+	return ext
+}
+
+// decryptKey derives the 32-byte file key from the encoded key string
+func decryptKey(encodeKey string) []byte {
+	// keep only characters from the base64 alphabet
+	// (strings.ReplaceAll with a regex pattern would not match anything, so filter explicitly)
+	u := strings.Map(func(r rune) rune {
+		if (r >= 'A' && r <= 'Z') || (r >= 'a' && r <= 'z') ||
+			(r >= '0' && r <= '9') || r == '+' || r == '/' {
+			return r
+		}
+		return -1
+	}, encodeKey)
+
+	// input length and fixed 32-byte output size
+	o := len(u)
+	a := 32
+
+	// allocate the output byte array
+	c := make([]byte, a)
+
+	// decode loop
+	s := uint32(0) // accumulator
+	f := 0         // output array index
+	for l := 0; l < o; l++ {
+		r := l & 3 // position of the current character within its 4-character block (mod 4)
+		i := u[l]  // ASCII code of the current character
+
+		// map the current character to its 6-bit value
+		switch {
+		case i >= 65 && i < 91: // uppercase letters
+			s |= uint32(i-65) << uint32(6*(3-r))
+		case i >= 97 && i < 123: // lowercase letters
+			s |= uint32(i-71) << uint32(6*(3-r))
+		case i >= 48 && i < 58: // digits
+			s |= uint32(i+4) << uint32(6*(3-r))
+		case i == 43: // plus sign
+			s |= uint32(62) << uint32(6*(3-r))
+		case i == 47: // slash
+			s |= uint32(63) << uint32(6*(3-r))
+		}
+
+		// once the accumulator holds four characters, or at the last character, flush it to the output
+		if r == 3 || l == o-1 {
+			for e := 0; e < 3 && f < a; e, f = e+1, f+1 {
+				c[f] = byte(s >> (16 >> e & 24) & 255)
+			}
+			s = 0
+		}
+	}
+
+	return c
+}
+
+func (d *Quqi) linkFromPreview(id string) (*model.Link, error) {
+	var getDocResp GetDocRes
+	if _, err := d.request("", "/api/doc/getDoc", resty.MethodPost, func(req *resty.Request) {
+		req.SetFormData(map[string]string{
+			"quqi_id":   d.GroupID,
+			"tree_id":   "1",
+			"node_id":   id,
+			"client_id": d.ClientID,
+		})
+	}, &getDocResp); err != nil {
+		return nil, err
+	}
+	if getDocResp.Data.OriginPath == "" {
+		return nil, errors.New("cannot get link from preview")
+	}
+	return &model.Link{
+		URL: getDocResp.Data.OriginPath,
+		Header: http.Header{
+			"Origin": []string{"https://quqi.com"},
+			"Cookie": []string{d.Cookie},
+		},
+	}, nil
+}
+
+func (d *Quqi) linkFromDownload(id string) (*model.Link, error) {
+	var getDownloadResp GetDownloadResp
+	if _, err := d.request("", "/api/doc/getDownload", resty.MethodGet, func(req *resty.Request) {
+		req.SetQueryParams(map[string]string{
+			"quqi_id":     d.GroupID,
+			"tree_id":     "1",
+			"node_id":     id,
+			"url_type":    "undefined",
+			"entry_type":  "undefined",
+			"client_id":   d.ClientID,
+			"no_redirect": "1",
+		})
+	}, &getDownloadResp); err != nil {
+		return nil, err
+	}
+	if getDownloadResp.Data.Url == "" {
+		return nil, errors.New("cannot get link from download")
+	}
+
+	return &model.Link{
+		URL: getDownloadResp.Data.Url,
+		Header: http.Header{
+			"Origin": []string{"https://quqi.com"},
+			"Cookie": []string{d.Cookie},
+		},
+	}, nil
+}
+
+func (d *Quqi) linkFromCDN(id string) (*model.Link, error) {
+	downloadLink, err := d.linkFromDownload(id)
+	if err != nil {
+		return nil, err
+	}
+
+	var urlExchangeResp UrlExchangeResp
+	if _, err = d.request("api.quqi.com", "/preview/downloadInfo/url/exchange", resty.MethodGet, func(req *resty.Request) {
+		req.SetQueryParam("url", downloadLink.URL)
+	}, &urlExchangeResp); err != nil {
+		return nil, err
+	}
+	if urlExchangeResp.Data.Url == "" {
+		return nil, errors.New("cannot get link from cdn")
+	}
+
+	// handle the case where the file is not encrypted at all
+	if !urlExchangeResp.Data.IsEncrypted {
+		return &model.Link{
+			URL: urlExchangeResp.Data.Url,
+			Header: http.Header{
+				"Origin": []string{"https://quqi.com"},
+				"Cookie": []string{d.Cookie},
+			},
+		}, nil
+	}
+
+	// From the sio spec (https://github.com/minio/sio/blob/master/DARE.md) and actual testing:
+	// 1. encrypted_size - size = package header + authentication tag = (16+16) * N, where N is the number of encrypted packages
+	// 2. N = (size + 64*1024 - 1) / (64*1024), i.e. each package carries a 64 KiB payload
+	remoteClosers := utils.EmptyClosers()
+	payloadSize := int64(1 << 16)
+	expiration := time.Until(time.Unix(urlExchangeResp.Data.ExpiredTime, 0))
+	resultRangeReader := func(ctx context.Context, httpRange http_range.Range) (io.ReadCloser, error) {
+		encryptedOffset := httpRange.Start / payloadSize * (payloadSize + 32)
+		decryptedOffset := httpRange.Start % payloadSize
+		encryptedLength := (httpRange.Length+httpRange.Start+payloadSize-1)/payloadSize*(payloadSize+32) - encryptedOffset
+		if httpRange.Length < 0 {
+			encryptedLength = httpRange.Length
+		} else {
+			if httpRange.Length+httpRange.Start >= urlExchangeResp.Data.Size || encryptedLength+encryptedOffset >= urlExchangeResp.Data.EncryptedSize {
+				encryptedLength = -1
+			}
+		}
+		//log.Debugf("size: %d\tencrypted_size: %d", urlExchangeResp.Data.Size, urlExchangeResp.Data.EncryptedSize)
+		//log.Debugf("http range offset: %d, length: %d", httpRange.Start, httpRange.Length)
+		//log.Debugf("encrypted offset: %d, length: %d, decrypted offset: %d", encryptedOffset, encryptedLength, decryptedOffset)
+
+		rrc, err := stream.GetRangeReadCloserFromLink(urlExchangeResp.Data.EncryptedSize, &model.Link{
+			URL: urlExchangeResp.Data.Url,
+			Header: http.Header{
+				"Origin": []string{"https://quqi.com"},
+				"Cookie": []string{d.Cookie},
+			},
+		})
+		if err != nil {
+			return nil, err
+		}
+
+		rc, err := rrc.RangeRead(ctx, http_range.Range{Start: encryptedOffset, Length: encryptedLength})
+		remoteClosers.AddClosers(rrc.GetClosers())
+		if err != nil {
+			return nil, err
+		}
+
+		decryptReader, err := sio.DecryptReader(rc, sio.Config{
+			MinVersion:     sio.Version10,
+			MaxVersion:     sio.Version20,
+			CipherSuites:   []byte{sio.CHACHA20_POLY1305, sio.AES_256_GCM},
+			Key:            decryptKey(urlExchangeResp.Data.EncryptedKey),
+			SequenceNumber: uint32(httpRange.Start / payloadSize),
+		})
+		if err != nil {
+			return nil, err
+		}
+		bufferReader := bufio.NewReader(decryptReader)
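+		// Worked example of the offset math above, for a client asking for httpRange.Start = 100000:
+		//   package index   = 100000 / 65536   = 1      (also the sio SequenceNumber)
+		//   encryptedOffset = 1 * (65536 + 32) = 65568
+		//   decryptedOffset = 100000 % 65536   = 34464
+		// so the fetch starts at byte 65568 of the ciphertext, and the Discard below skips the
+		// first 34464 plaintext bytes of that package to land exactly on byte 100000.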
+		if _, err := bufferReader.Discard(int(decryptedOffset)); err != nil {
+			return nil, err
+		}
+
+		return utils.NewReadCloser(bufferReader, func() error {
+			return nil
+		}), nil
+	}
+
+	return &model.Link{
+		Header: http.Header{
+			"Origin": []string{"https://quqi.com"},
+			"Cookie": []string{d.Cookie},
+		},
+		RangeReadCloser: &model.RangeReadCloser{RangeReader: resultRangeReader, Closers: remoteClosers},
+		Expiration:      &expiration,
+	}, nil
+}
diff --git a/drivers/s3/doge.go b/drivers/s3/doge.go
new file mode 100644
index 0000000000000000000000000000000000000000..12a584ca4f225e8c83975398ac4611c9a2e2283e
--- /dev/null
+++ b/drivers/s3/doge.go
@@ -0,0 +1,63 @@
+package s3
+
+import (
+	"crypto/hmac"
+	"crypto/sha1"
+	"encoding/hex"
+	"encoding/json"
+	"io"
+	"net/http"
+	"strings"
+)
+
+type TmpTokenResponse struct {
+	Code int                  `json:"code"`
+	Msg  string               `json:"msg"`
+	Data TmpTokenResponseData `json:"data,omitempty"`
+}
+type TmpTokenResponseData struct {
+	Credentials Credentials `json:"Credentials"`
+	ExpiredAt   int         `json:"ExpiredAt"`
+}
+type Credentials struct {
+	AccessKeyId     string `json:"accessKeyId,omitempty"`
+	SecretAccessKey string `json:"secretAccessKey,omitempty"`
+	SessionToken    string `json:"sessionToken,omitempty"`
+}
+
+func getCredentials(AccessKey, SecretKey string) (rst Credentials, err error) {
+	apiPath := "/auth/tmp_token.json"
+	reqBody, err := json.Marshal(map[string]interface{}{"channel": "OSS_FULL", "scopes": []string{"*"}})
+	if err != nil {
+		return rst, err
+	}
+
+	signStr := apiPath + "\n" + string(reqBody)
+	hmacObj := hmac.New(sha1.New, []byte(SecretKey))
+	hmacObj.Write([]byte(signStr))
+	sign := hex.EncodeToString(hmacObj.Sum(nil))
+	Authorization := "TOKEN " + AccessKey + ":" + sign
+
+	req, err := http.NewRequest("POST", "https://api.dogecloud.com"+apiPath, strings.NewReader(string(reqBody)))
+	if err != nil {
+		return rst, err
+	}
+	req.Header.Add("Content-Type", "application/json")
+	req.Header.Add("Authorization", Authorization)
+	client := http.Client{}
+	resp, err := client.Do(req)
+	if err != nil {
+		return rst, err
+	}
+	defer resp.Body.Close()
+	ret, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return rst, err
+	}
+	var tmpTokenResp TmpTokenResponse
+	err = json.Unmarshal(ret, &tmpTokenResp)
+	if err != nil {
+		return rst, err
+	}
+	return tmpTokenResp.Data.Credentials, nil
+}
diff --git a/drivers/s3/driver.go b/drivers/s3/driver.go
new file mode 100644
index 0000000000000000000000000000000000000000..82c050a1fe8e259c0cfd250c3dbc22a37dc24262
--- /dev/null
+++ b/drivers/s3/driver.go
@@ -0,0 +1,184 @@
+package s3
+
+import (
+	"bytes"
+	"context"
+	"fmt"
+	"github.com/alist-org/alist/v3/server/common"
+	"io"
+	"net/url"
+	stdpath "path"
+	"strings"
+	"time"
+
+	"github.com/alist-org/alist/v3/internal/stream"
+	"github.com/alist-org/alist/v3/pkg/cron"
+
+	"github.com/alist-org/alist/v3/internal/driver"
+	"github.com/alist-org/alist/v3/internal/model"
+	"github.com/aws/aws-sdk-go/aws/session"
+	"github.com/aws/aws-sdk-go/service/s3"
+	"github.com/aws/aws-sdk-go/service/s3/s3manager"
+	log "github.com/sirupsen/logrus"
+)
+
+type S3 struct {
+	model.Storage
+	Addition
+	Session    *session.Session
+	client     *s3.S3
+	linkClient *s3.S3
+
+	config driver.Config
+	cron   *cron.Cron
+}
+
+func (d *S3) Config() driver.Config {
+	return d.config
+}
+
+func (d *S3) GetAddition() driver.Additional {
+	return &d.Addition
+}
+
+func (d *S3) Init(ctx context.Context) error {
+	if d.Region == "" {
+		d.Region = "alist"
+	}
+	if d.config.Name == "Doge" {
+		// DogeCloud temporary keys are valid for 2 hours, so regenerate them every 118 minutes
+		d.cron =
cron.NewCron(time.Minute * 118) + d.cron.Do(func() { + err := d.initSession() + if err != nil { + log.Errorln("Doge init session error:", err) + } + d.client = d.getClient(false) + d.linkClient = d.getClient(true) + }) + } + err := d.initSession() + if err != nil { + return err + } + d.client = d.getClient(false) + d.linkClient = d.getClient(true) + return nil +} + +func (d *S3) Drop(ctx context.Context) error { + if d.cron != nil { + d.cron.Stop() + } + return nil +} + +func (d *S3) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) { + if d.ListObjectVersion == "v2" { + return d.listV2(dir.GetPath(), args) + } + return d.listV1(dir.GetPath(), args) +} + +func (d *S3) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) { + path := getKey(file.GetPath(), false) + filename := stdpath.Base(path) + disposition := fmt.Sprintf(`attachment; filename*=UTF-8''%s`, url.PathEscape(filename)) + if d.AddFilenameToDisposition { + disposition = fmt.Sprintf(`attachment; filename="%s"; filename*=UTF-8''%s`, filename, url.PathEscape(filename)) + } + input := &s3.GetObjectInput{ + Bucket: &d.Bucket, + Key: &path, + //ResponseContentDisposition: &disposition, + } + if d.CustomHost == "" { + input.ResponseContentDisposition = &disposition + } + req, _ := d.linkClient.GetObjectRequest(input) + var link model.Link + var err error + if d.CustomHost != "" { + if d.EnableCustomHostPresign { + link.URL, err = req.Presign(time.Hour * time.Duration(d.SignURLExpire)) + } else { + err = req.Build() + link.URL = req.HTTPRequest.URL.String() + } + if d.RemoveBucket { + link.URL = strings.Replace(link.URL, "/"+d.Bucket, "", 1) + } + } else { + if common.ShouldProxy(d, filename) { + err = req.Sign() + link.URL = req.HTTPRequest.URL.String() + link.Header = req.HTTPRequest.Header + } else { + link.URL, err = req.Presign(time.Hour * time.Duration(d.SignURLExpire)) + } + } + if err != nil { + return nil, err + } + return &link, nil +} + +func (d *S3) MakeDir(ctx context.Context, parentDir model.Obj, dirName string) error { + return d.Put(ctx, &model.Object{ + Path: stdpath.Join(parentDir.GetPath(), dirName), + }, &stream.FileStream{ + Obj: &model.Object{ + Name: getPlaceholderName(d.Placeholder), + Modified: time.Now(), + }, + Reader: io.NopCloser(bytes.NewReader([]byte{})), + Mimetype: "application/octet-stream", + }, func(float64) {}) +} + +func (d *S3) Move(ctx context.Context, srcObj, dstDir model.Obj) error { + err := d.Copy(ctx, srcObj, dstDir) + if err != nil { + return err + } + return d.Remove(ctx, srcObj) +} + +func (d *S3) Rename(ctx context.Context, srcObj model.Obj, newName string) error { + err := d.copy(ctx, srcObj.GetPath(), stdpath.Join(stdpath.Dir(srcObj.GetPath()), newName), srcObj.IsDir()) + if err != nil { + return err + } + return d.Remove(ctx, srcObj) +} + +func (d *S3) Copy(ctx context.Context, srcObj, dstDir model.Obj) error { + return d.copy(ctx, srcObj.GetPath(), stdpath.Join(dstDir.GetPath(), srcObj.GetName()), srcObj.IsDir()) +} + +func (d *S3) Remove(ctx context.Context, obj model.Obj) error { + if obj.IsDir() { + return d.removeDir(ctx, obj.GetPath()) + } + return d.removeFile(obj.GetPath()) +} + +func (d *S3) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error { + uploader := s3manager.NewUploader(d.Session) + if stream.GetSize() > s3manager.MaxUploadParts*s3manager.DefaultUploadPartSize { + uploader.PartSize = stream.GetSize() / (s3manager.MaxUploadParts - 1) + } + key := 
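+	// getKey (util.go) strips the leading "/" so the object key is bucket-relative. The part-size
+	// bump above keeps large streams under the SDK's 10000-part cap: with the default 5 MiB parts
+	// that cap is reached near 48.8 GiB, so anything bigger gets size/(MaxUploadParts-1) per part.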
getKey(stdpath.Join(dstDir.GetPath(), stream.GetName()), false) + contentType := stream.GetMimetype() + log.Debugln("key:", key) + input := &s3manager.UploadInput{ + Bucket: &d.Bucket, + Key: &key, + Body: stream, + ContentType: &contentType, + } + _, err := uploader.UploadWithContext(ctx, input) + return err +} + +var _ driver.Driver = (*S3)(nil) diff --git a/drivers/s3/meta.go b/drivers/s3/meta.go new file mode 100644 index 0000000000000000000000000000000000000000..4de4b60a6902933d4a9b30457d99b5c4414fdbac --- /dev/null +++ b/drivers/s3/meta.go @@ -0,0 +1,47 @@ +package s3 + +import ( + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/op" +) + +type Addition struct { + driver.RootPath + Bucket string `json:"bucket" required:"true"` + Endpoint string `json:"endpoint" required:"true"` + Region string `json:"region"` + AccessKeyID string `json:"access_key_id" required:"true"` + SecretAccessKey string `json:"secret_access_key" required:"true"` + SessionToken string `json:"session_token"` + CustomHost string `json:"custom_host"` + EnableCustomHostPresign bool `json:"enable_custom_host_presign"` + SignURLExpire int `json:"sign_url_expire" type:"number" default:"4"` + Placeholder string `json:"placeholder"` + ForcePathStyle bool `json:"force_path_style"` + ListObjectVersion string `json:"list_object_version" type:"select" options:"v1,v2" default:"v1"` + RemoveBucket bool `json:"remove_bucket" help:"Remove bucket name from path when using custom host."` + AddFilenameToDisposition bool `json:"add_filename_to_disposition" help:"Add filename to Content-Disposition header."` +} + +func init() { + op.RegisterDriver(func() driver.Driver { + return &S3{ + config: driver.Config{ + Name: "S3", + DefaultRoot: "/", + LocalSort: true, + CheckStatus: true, + }, + } + }) + op.RegisterDriver(func() driver.Driver { + return &S3{ + config: driver.Config{ + Name: "Doge", + DefaultRoot: "/", + LocalSort: true, + CheckStatus: true, + }, + } + }) +} diff --git a/drivers/s3/types.go b/drivers/s3/types.go new file mode 100644 index 0000000000000000000000000000000000000000..3ed7f97237d3800e695e2c39364425ba7bd56e1b --- /dev/null +++ b/drivers/s3/types.go @@ -0,0 +1 @@ +package s3 diff --git a/drivers/s3/util.go b/drivers/s3/util.go new file mode 100644 index 0000000000000000000000000000000000000000..31e658bdcab76003c8feccc192085287c9504b5b --- /dev/null +++ b/drivers/s3/util.go @@ -0,0 +1,257 @@ +package s3 + +import ( + "context" + "errors" + "net/http" + "path" + "strings" + + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/internal/op" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go/aws/request" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/s3" + log "github.com/sirupsen/logrus" +) + +// do others that not defined in Driver interface + +func (d *S3) initSession() error { + var err error + accessKeyID, secretAccessKey, sessionToken := d.AccessKeyID, d.SecretAccessKey, d.SessionToken + if d.config.Name == "Doge" { + credentialsTmp, err := getCredentials(d.AccessKeyID, d.SecretAccessKey) + if err != nil { + return err + } + accessKeyID, secretAccessKey, sessionToken = credentialsTmp.AccessKeyId, credentialsTmp.SecretAccessKey, credentialsTmp.SessionToken + } + cfg := &aws.Config{ + Credentials: credentials.NewStaticCredentials(accessKeyID, secretAccessKey, sessionToken), + Region: &d.Region, + Endpoint: 
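+		// For the plain S3 driver these are the user's static keys (plus optional session token);
+		// for the "Doge" variant they were just swapped for the temporary credentials fetched from
+		// the DogeCloud API above, which is why Init re-runs this function on a 118-minute cron.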
&d.Endpoint,
+		S3ForcePathStyle: aws.Bool(d.ForcePathStyle),
+	}
+	d.Session, err = session.NewSession(cfg)
+	return err
+}
+
+func (d *S3) getClient(link bool) *s3.S3 {
+	client := s3.New(d.Session)
+	if link && d.CustomHost != "" {
+		client.Handlers.Build.PushBack(func(r *request.Request) {
+			if r.HTTPRequest.Method != http.MethodGet {
+				return
+			}
+			// check whether CustomHost starts with http:// or https://
+			split := strings.SplitN(d.CustomHost, "://", 2)
+			if utils.SliceContains([]string{"http", "https"}, split[0]) {
+				r.HTTPRequest.URL.Scheme = split[0]
+				r.HTTPRequest.URL.Host = split[1]
+			} else {
+				r.HTTPRequest.URL.Host = d.CustomHost
+			}
+		})
+	}
+	return client
+}
+
+func getKey(path string, dir bool) string {
+	path = strings.TrimPrefix(path, "/")
+	if path != "" && dir {
+		path += "/"
+	}
+	return path
+}
+
+var defaultPlaceholderName = ".alist"
+
+func getPlaceholderName(placeholder string) string {
+	if placeholder == "" {
+		return defaultPlaceholderName
+	}
+	return placeholder
+}
+
+func (d *S3) listV1(prefix string, args model.ListArgs) ([]model.Obj, error) {
+	prefix = getKey(prefix, true)
+	log.Debugf("list: %s", prefix)
+	files := make([]model.Obj, 0)
+	marker := ""
+	for {
+		input := &s3.ListObjectsInput{
+			Bucket:    &d.Bucket,
+			Marker:    &marker,
+			Prefix:    &prefix,
+			Delimiter: aws.String("/"),
+		}
+		listObjectsResult, err := d.client.ListObjects(input)
+		if err != nil {
+			return nil, err
+		}
+		for _, object := range listObjectsResult.CommonPrefixes {
+			name := path.Base(strings.Trim(*object.Prefix, "/"))
+			file := model.Object{
+				//Id:       *object.Key,
+				Name:     name,
+				Modified: d.Modified,
+				IsFolder: true,
+			}
+			files = append(files, &file)
+		}
+		for _, object := range listObjectsResult.Contents {
+			name := path.Base(*object.Key)
+			if !args.S3ShowPlaceholder && (name == getPlaceholderName(d.Placeholder) || name == d.Placeholder) {
+				continue
+			}
+			file := model.Object{
+				//Id:       *object.Key,
+				Name:     name,
+				Size:     *object.Size,
+				Modified: *object.LastModified,
+			}
+			files = append(files, &file)
+		}
+		if listObjectsResult.IsTruncated == nil {
+			return nil, errors.New("ListObjects response has nil IsTruncated")
+		}
+		if *listObjectsResult.IsTruncated {
+			marker = *listObjectsResult.NextMarker
+		} else {
+			break
+		}
+	}
+	return files, nil
+}
+
+func (d *S3) listV2(prefix string, args model.ListArgs) ([]model.Obj, error) {
+	prefix = getKey(prefix, true)
+	files := make([]model.Obj, 0)
+	var continuationToken, startAfter *string
+	for {
+		input := &s3.ListObjectsV2Input{
+			Bucket:            &d.Bucket,
+			ContinuationToken: continuationToken,
+			Prefix:            &prefix,
+			Delimiter:         aws.String("/"),
+			StartAfter:        startAfter,
+		}
+		listObjectsResult, err := d.client.ListObjectsV2(input)
+		if err != nil {
+			return nil, err
+		}
+		log.Debugf("resp: %+v", listObjectsResult)
+		for _, object := range listObjectsResult.CommonPrefixes {
+			name := path.Base(strings.Trim(*object.Prefix, "/"))
+			file := model.Object{
+				//Id:       *object.Key,
+				Name:     name,
+				Modified: d.Modified,
+				IsFolder: true,
+			}
+			files = append(files, &file)
+		}
+		for _, object := range listObjectsResult.Contents {
+			if strings.HasSuffix(*object.Key, "/") {
+				continue
+			}
+			name := path.Base(*object.Key)
+			if !args.S3ShowPlaceholder && (name == getPlaceholderName(d.Placeholder) || name == d.Placeholder) {
+				continue
+			}
+			file := model.Object{
+				//Id:       *object.Key,
+				Name:     name,
+				Size:     *object.Size,
+				Modified: *object.LastModified,
+			}
+			files = append(files, &file)
+		}
+		if !aws.BoolValue(listObjectsResult.IsTruncated) {
+			break
+		}
+		if listObjectsResult.NextContinuationToken != nil
{ + continuationToken = listObjectsResult.NextContinuationToken + continue + } + if len(listObjectsResult.Contents) == 0 { + break + } + startAfter = listObjectsResult.Contents[len(listObjectsResult.Contents)-1].Key + } + return files, nil +} + +func (d *S3) copy(ctx context.Context, src string, dst string, isDir bool) error { + if isDir { + return d.copyDir(ctx, src, dst) + } + return d.copyFile(ctx, src, dst) +} + +func (d *S3) copyFile(ctx context.Context, src string, dst string) error { + srcKey := getKey(src, false) + dstKey := getKey(dst, false) + input := &s3.CopyObjectInput{ + Bucket: &d.Bucket, + CopySource: aws.String("/" + d.Bucket + "/" + srcKey), + Key: &dstKey, + } + _, err := d.client.CopyObject(input) + return err +} + +func (d *S3) copyDir(ctx context.Context, src string, dst string) error { + objs, err := op.List(ctx, d, src, model.ListArgs{S3ShowPlaceholder: true}) + if err != nil { + return err + } + for _, obj := range objs { + cSrc := path.Join(src, obj.GetName()) + cDst := path.Join(dst, obj.GetName()) + if obj.IsDir() { + err = d.copyDir(ctx, cSrc, cDst) + } else { + err = d.copyFile(ctx, cSrc, cDst) + } + if err != nil { + return err + } + } + return nil +} + +func (d *S3) removeDir(ctx context.Context, src string) error { + objs, err := op.List(ctx, d, src, model.ListArgs{}) + if err != nil { + return err + } + for _, obj := range objs { + cSrc := path.Join(src, obj.GetName()) + if obj.IsDir() { + err = d.removeDir(ctx, cSrc) + } else { + err = d.removeFile(cSrc) + } + if err != nil { + return err + } + } + _ = d.removeFile(path.Join(src, getPlaceholderName(d.Placeholder))) + _ = d.removeFile(path.Join(src, d.Placeholder)) + return nil +} + +func (d *S3) removeFile(src string) error { + key := getKey(src, false) + input := &s3.DeleteObjectInput{ + Bucket: &d.Bucket, + Key: &key, + } + _, err := d.client.DeleteObject(input) + return err +} diff --git a/drivers/seafile/driver.go b/drivers/seafile/driver.go new file mode 100644 index 0000000000000000000000000000000000000000..6d1f16dad3ba4f2f86dc99e443f1e38d26a3587a --- /dev/null +++ b/drivers/seafile/driver.go @@ -0,0 +1,226 @@ +package seafile + +import ( + "context" + "fmt" + "net/http" + "strings" + "time" + + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/go-resty/resty/v2" +) + +type Seafile struct { + model.Storage + Addition + + authorization string + libraryMap map[string]*LibraryInfo +} + +func (d *Seafile) Config() driver.Config { + return config +} + +func (d *Seafile) GetAddition() driver.Additional { + return &d.Addition +} + +func (d *Seafile) Init(ctx context.Context) error { + d.Address = strings.TrimSuffix(d.Address, "/") + d.RootFolderPath = utils.FixAndCleanPath(d.RootFolderPath) + d.libraryMap = make(map[string]*LibraryInfo) + return d.getToken() +} + +func (d *Seafile) Drop(ctx context.Context) error { + return nil +} + +func (d *Seafile) List(ctx context.Context, dir model.Obj, args model.ListArgs) (result []model.Obj, err error) { + path := dir.GetPath() + if path == d.RootFolderPath { + libraries, err := d.listLibraries() + if err != nil { + return nil, err + } + if path == "/" && d.RepoId == "" { + return utils.SliceConvert(libraries, func(f LibraryItemResp) (model.Obj, error) { + return &model.Object{ + Name: f.Name, + Modified: time.Unix(f.Modified, 0), + Size: f.Size, + IsFolder: true, + }, nil + }) + } + } + var repo *LibraryInfo + repo, path, err = d.getRepoAndPath(path) + if err 
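// Note for the Seafile listing that begins here: alist paths map onto
// Seafile as /<library name>/<path inside library>, so List first resolves
// which library a path belongs to via getRepoAndPath (defined in util.go
// below). A sketch of just that split, under an illustrative name; the
// real helper also handles a pinned RepoId and a cached libraryMap:
func splitLibraryPath(fullPath string) (library, sub string) {
	trimmed := strings.TrimPrefix(fullPath, "/")
	if i := strings.IndexRune(trimmed, '/'); i >= 0 {
		return trimmed[:i], trimmed[i:]
	}
	return trimmed, "/"
}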
!= nil { + return nil, err + } + if repo.Encrypted { + err = d.decryptLibrary(repo) + if err != nil { + return nil, err + } + } + var resp []RepoDirItemResp + _, err = d.request(http.MethodGet, fmt.Sprintf("/api2/repos/%s/dir/", repo.Id), func(req *resty.Request) { + req.SetResult(&resp).SetQueryParams(map[string]string{ + "p": path, + }) + }) + if err != nil { + return nil, err + } + return utils.SliceConvert(resp, func(f RepoDirItemResp) (model.Obj, error) { + return &model.ObjThumb{ + Object: model.Object{ + Name: f.Name, + Modified: time.Unix(f.Modified, 0), + Size: f.Size, + IsFolder: f.Type == "dir", + }, + // Thumbnail: model.Thumbnail{Thumbnail: f.Thumb}, + }, nil + }) +} + +func (d *Seafile) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) { + repo, path, err := d.getRepoAndPath(file.GetPath()) + if err != nil { + return nil, err + } + res, err := d.request(http.MethodGet, fmt.Sprintf("/api2/repos/%s/file/", repo.Id), func(req *resty.Request) { + req.SetQueryParams(map[string]string{ + "p": path, + "reuse": "1", + }) + }) + if err != nil { + return nil, err + } + u := string(res) + u = u[1 : len(u)-1] // remove quotes + return &model.Link{URL: u}, nil +} + +func (d *Seafile) MakeDir(ctx context.Context, parentDir model.Obj, dirName string) error { + repo, path, err := d.getRepoAndPath(parentDir.GetPath()) + if err != nil { + return err + } + path, _ = utils.JoinBasePath(path, dirName) + _, err = d.request(http.MethodPost, fmt.Sprintf("/api2/repos/%s/dir/", repo.Id), func(req *resty.Request) { + req.SetQueryParams(map[string]string{ + "p": path, + }).SetFormData(map[string]string{ + "operation": "mkdir", + }) + }) + return err +} + +func (d *Seafile) Move(ctx context.Context, srcObj, dstDir model.Obj) error { + repo, path, err := d.getRepoAndPath(srcObj.GetPath()) + if err != nil { + return err + } + dstRepo, dstPath, err := d.getRepoAndPath(dstDir.GetPath()) + if err != nil { + return err + } + _, err = d.request(http.MethodPost, fmt.Sprintf("/api2/repos/%s/file/", repo.Id), func(req *resty.Request) { + req.SetQueryParams(map[string]string{ + "p": path, + }).SetFormData(map[string]string{ + "operation": "move", + "dst_repo": dstRepo.Id, + "dst_dir": dstPath, + }) + }, true) + return err +} + +func (d *Seafile) Rename(ctx context.Context, srcObj model.Obj, newName string) error { + repo, path, err := d.getRepoAndPath(srcObj.GetPath()) + if err != nil { + return err + } + _, err = d.request(http.MethodPost, fmt.Sprintf("/api2/repos/%s/file/", repo.Id), func(req *resty.Request) { + req.SetQueryParams(map[string]string{ + "p": path, + }).SetFormData(map[string]string{ + "operation": "rename", + "newname": newName, + }) + }, true) + return err +} + +func (d *Seafile) Copy(ctx context.Context, srcObj, dstDir model.Obj) error { + repo, path, err := d.getRepoAndPath(srcObj.GetPath()) + if err != nil { + return err + } + dstRepo, dstPath, err := d.getRepoAndPath(dstDir.GetPath()) + if err != nil { + return err + } + _, err = d.request(http.MethodPost, fmt.Sprintf("/api2/repos/%s/file/", repo.Id), func(req *resty.Request) { + req.SetQueryParams(map[string]string{ + "p": path, + }).SetFormData(map[string]string{ + "operation": "copy", + "dst_repo": dstRepo.Id, + "dst_dir": dstPath, + }) + }) + return err +} + +func (d *Seafile) Remove(ctx context.Context, obj model.Obj) error { + repo, path, err := d.getRepoAndPath(obj.GetPath()) + if err != nil { + return err + } + _, err = d.request(http.MethodDelete, fmt.Sprintf("/api2/repos/%s/file/", repo.Id), func(req 
*resty.Request) { + req.SetQueryParams(map[string]string{ + "p": path, + }) + }) + return err +} + +func (d *Seafile) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error { + repo, path, err := d.getRepoAndPath(dstDir.GetPath()) + if err != nil { + return err + } + res, err := d.request(http.MethodGet, fmt.Sprintf("/api2/repos/%s/upload-link/", repo.Id), func(req *resty.Request) { + req.SetQueryParams(map[string]string{ + "p": path, + }) + }) + if err != nil { + return err + } + + u := string(res) + u = u[1 : len(u)-1] // remove quotes + _, err = d.request(http.MethodPost, u, func(req *resty.Request) { + req.SetFileReader("file", stream.GetName(), stream). + SetFormData(map[string]string{ + "parent_dir": path, + "replace": "1", + }) + }) + return err +} + +var _ driver.Driver = (*Seafile)(nil) diff --git a/drivers/seafile/meta.go b/drivers/seafile/meta.go new file mode 100644 index 0000000000000000000000000000000000000000..fd5255f592b70876b99554324fbb920733d12d31 --- /dev/null +++ b/drivers/seafile/meta.go @@ -0,0 +1,28 @@ +package seafile + +import ( + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/op" +) + +type Addition struct { + driver.RootPath + + Address string `json:"address" required:"true"` + UserName string `json:"username" required:"false"` + Password string `json:"password" required:"false"` + Token string `json:"token" required:"false"` + RepoId string `json:"repoId" required:"false"` + RepoPwd string `json:"repoPwd" required:"false"` +} + +var config = driver.Config{ + Name: "Seafile", + DefaultRoot: "/", +} + +func init() { + op.RegisterDriver(func() driver.Driver { + return &Seafile{} + }) +} diff --git a/drivers/seafile/types.go b/drivers/seafile/types.go new file mode 100644 index 0000000000000000000000000000000000000000..47cb322df4af5b521bc0c230cc48189dca1f0880 --- /dev/null +++ b/drivers/seafile/types.go @@ -0,0 +1,44 @@ +package seafile + +import "time" + +type AuthTokenResp struct { + Token string `json:"token"` +} + +type RepoItemResp struct { + Id string `json:"id"` + Type string `json:"type"` // repo, dir, file + Name string `json:"name"` + Size int64 `json:"size"` + Modified int64 `json:"mtime"` + Permission string `json:"permission"` +} + +type LibraryItemResp struct { + RepoItemResp + OwnerContactEmail string `json:"owner_contact_email"` + OwnerName string `json:"owner_name"` + Owner string `json:"owner"` + ModifierEmail string `json:"modifier_email"` + ModifierContactEmail string `json:"modifier_contact_email"` + ModifierName string `json:"modifier_name"` + Virtual bool `json:"virtual"` + MtimeRelative string `json:"mtime_relative"` + Encrypted bool `json:"encrypted"` + Version int `json:"version"` + HeadCommitId string `json:"head_commit_id"` + Root string `json:"root"` + Salt string `json:"salt"` + SizeFormatted string `json:"size_formatted"` +} + +type RepoDirItemResp struct { + RepoItemResp +} + +type LibraryInfo struct { + LibraryItemResp + decryptedTime time.Time + decryptedSuccess bool +} \ No newline at end of file diff --git a/drivers/seafile/util.go b/drivers/seafile/util.go new file mode 100644 index 0000000000000000000000000000000000000000..89b7b0fc39f3b345f1e3f293ff8888073c85d340 --- /dev/null +++ b/drivers/seafile/util.go @@ -0,0 +1,178 @@ +package seafile + +import ( + "errors" + "fmt" + "github.com/alist-org/alist/v3/internal/errs" + "github.com/alist-org/alist/v3/pkg/utils" + "net/http" + "strings" + "time" + + 
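// The Put implementation above is Seafile's two-step upload: request an
// upload link for the target directory, then POST the file to that link as
// multipart form data. The same flow in isolation (uploadSketch and its
// parameters are illustrative, not part of this driver; assumes io and fmt
// are imported):
func uploadSketch(client *resty.Client, uploadLink, parentDir, name string, r io.Reader) error {
	res, err := client.R().
		SetFileReader("file", name, r). // the "file" field carries the content
		SetFormData(map[string]string{
			"parent_dir": parentDir,
			"replace":    "1", // overwrite an existing file with the same name
		}).
		Post(uploadLink)
	if err != nil {
		return err
	}
	if res.StatusCode() >= 400 {
		return fmt.Errorf("upload failed: %s", res.String())
	}
	return nil
}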
"github.com/alist-org/alist/v3/drivers/base" + "github.com/go-resty/resty/v2" +) + +func (d *Seafile) getToken() error { + if d.Token != "" { + d.authorization = fmt.Sprintf("Token %s", d.Token) + return nil + } + var authResp AuthTokenResp + res, err := base.RestyClient.R(). + SetResult(&authResp). + SetFormData(map[string]string{ + "username": d.UserName, + "password": d.Password, + }). + Post(d.Address + "/api2/auth-token/") + if err != nil { + return err + } + if res.StatusCode() >= 400 { + return fmt.Errorf("get token failed: %s", res.String()) + } + d.authorization = fmt.Sprintf("Token %s", authResp.Token) + return nil +} + +func (d *Seafile) request(method string, pathname string, callback base.ReqCallback, noRedirect ...bool) ([]byte, error) { + full := pathname + if !strings.HasPrefix(pathname, "http") { + full = d.Address + pathname + } + req := base.RestyClient.R() + if len(noRedirect) > 0 && noRedirect[0] { + req = base.NoRedirectClient.R() + } + req.SetHeader("Authorization", d.authorization) + callback(req) + var ( + res *resty.Response + err error + ) + for i := 0; i < 2; i++ { + res, err = req.Execute(method, full) + if err != nil { + return nil, err + } + if res.StatusCode() != 401 { // Unauthorized + break + } + err = d.getToken() + if err != nil { + return nil, err + } + } + if res.StatusCode() >= 400 { + return nil, fmt.Errorf("request failed: %s", res.String()) + } + return res.Body(), nil +} + +func (d *Seafile) getRepoAndPath(fullPath string) (repo *LibraryInfo, path string, err error) { + libraryMap := d.libraryMap + repoId := d.Addition.RepoId + if repoId != "" { + if len(repoId) == 36 /* uuid */ { + for _, library := range libraryMap { + if library.Id == repoId { + return library, fullPath, nil + } + } + } + } else { + var repoName string + str := fullPath[1:] + pos := strings.IndexRune(str, '/') + if pos == -1 { + repoName = str + } else { + repoName = str[:pos] + } + path = utils.FixAndCleanPath(fullPath[1+len(repoName):]) + if library, ok := libraryMap[repoName]; ok { + return library, path, nil + } + } + return nil, "", errs.ObjectNotFound +} + +func (d *Seafile) listLibraries() (resp []LibraryItemResp, err error) { + repoId := d.Addition.RepoId + if repoId == "" { + _, err = d.request(http.MethodGet, "/api2/repos/", func(req *resty.Request) { + req.SetResult(&resp) + }) + } else { + var oneResp LibraryItemResp + _, err = d.request(http.MethodGet, fmt.Sprintf("/api2/repos/%s/", repoId), func(req *resty.Request) { + req.SetResult(&oneResp) + }) + if err == nil { + resp = append(resp, oneResp) + } + } + if err != nil { + return nil, err + } + libraryMap := make(map[string]*LibraryInfo) + var putLibraryMap func(library LibraryItemResp, index int) + putLibraryMap = func(library LibraryItemResp, index int) { + name := library.Name + if index > 0 { + name = fmt.Sprintf("%s (%d)", name, index) + } + if _, exist := libraryMap[name]; exist { + putLibraryMap(library, index+1) + } else { + libraryInfo := LibraryInfo{} + data, _ := utils.Json.Marshal(library) + _ = utils.Json.Unmarshal(data, &libraryInfo) + libraryMap[name] = &libraryInfo + } + } + for _, library := range resp { + putLibraryMap(library, 0) + } + d.libraryMap = libraryMap + return resp, nil +} + +var repoPwdNotConfigured = errors.New("library password not configured") +var repoPwdIncorrect = errors.New("library password is incorrect") + +func (d *Seafile) decryptLibrary(repo *LibraryInfo) (err error) { + if !repo.Encrypted { + return nil + } + if d.RepoPwd == "" { + return repoPwdNotConfigured + } + now := 
time.Now() + decryptedTime := repo.decryptedTime + if repo.decryptedSuccess { + if now.Sub(decryptedTime).Minutes() <= 30 { + return nil + } + } else { + if now.Sub(decryptedTime).Seconds() <= 10 { + return repoPwdIncorrect + } + } + var resp string + _, err = d.request(http.MethodPost, fmt.Sprintf("/api2/repos/%s/", repo.Id), func(req *resty.Request) { + req.SetResult(&resp).SetFormData(map[string]string{ + "password": d.RepoPwd, + }) + }) + repo.decryptedTime = time.Now() + if err != nil || !strings.Contains(resp, "success") { + repo.decryptedSuccess = false + return err + } + repo.decryptedSuccess = true + return nil +} + + diff --git a/drivers/sftp/driver.go b/drivers/sftp/driver.go new file mode 100644 index 0000000000000000000000000000000000000000..1f216598d2d0780126103eda28ad5c8d92b7c28a --- /dev/null +++ b/drivers/sftp/driver.go @@ -0,0 +1,118 @@ +package sftp + +import ( + "context" + "os" + "path" + + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/errs" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/pkg/sftp" + log "github.com/sirupsen/logrus" +) + +type SFTP struct { + model.Storage + Addition + client *sftp.Client + clientConnectionError error +} + +func (d *SFTP) Config() driver.Config { + return config +} + +func (d *SFTP) GetAddition() driver.Additional { + return &d.Addition +} + +func (d *SFTP) Init(ctx context.Context) error { + return d.initClient() +} + +func (d *SFTP) Drop(ctx context.Context) error { + if d.client != nil { + _ = d.client.Close() + } + return nil +} + +func (d *SFTP) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) { + if err := d.clientReconnectOnConnectionError(); err != nil { + return nil, err + } + log.Debugf("[sftp] list dir: %s", dir.GetPath()) + files, err := d.client.ReadDir(dir.GetPath()) + if err != nil { + return nil, err + } + objs, err := utils.SliceConvert(files, func(src os.FileInfo) (model.Obj, error) { + return d.fileToObj(src, dir.GetPath()) + }) + return objs, err +} + +func (d *SFTP) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) { + if err := d.clientReconnectOnConnectionError(); err != nil { + return nil, err + } + remoteFile, err := d.client.Open(file.GetPath()) + if err != nil { + return nil, err + } + link := &model.Link{ + MFile: remoteFile, + } + return link, nil +} + +func (d *SFTP) MakeDir(ctx context.Context, parentDir model.Obj, dirName string) error { + if err := d.clientReconnectOnConnectionError(); err != nil { + return err + } + return d.client.MkdirAll(path.Join(parentDir.GetPath(), dirName)) +} + +func (d *SFTP) Move(ctx context.Context, srcObj, dstDir model.Obj) error { + if err := d.clientReconnectOnConnectionError(); err != nil { + return err + } + return d.client.Rename(srcObj.GetPath(), path.Join(dstDir.GetPath(), srcObj.GetName())) +} + +func (d *SFTP) Rename(ctx context.Context, srcObj model.Obj, newName string) error { + if err := d.clientReconnectOnConnectionError(); err != nil { + return err + } + return d.client.Rename(srcObj.GetPath(), path.Join(path.Dir(srcObj.GetPath()), newName)) +} + +func (d *SFTP) Copy(ctx context.Context, srcObj, dstDir model.Obj) error { + return errs.NotSupport +} + +func (d *SFTP) Remove(ctx context.Context, obj model.Obj) error { + if err := d.clientReconnectOnConnectionError(); err != nil { + return err + } + return d.remove(obj.GetPath()) +} + +func (d *SFTP) Put(ctx context.Context, dstDir model.Obj, 
stream model.FileStreamer, up driver.UpdateProgress) error { + if err := d.clientReconnectOnConnectionError(); err != nil { + return err + } + dstFile, err := d.client.Create(path.Join(dstDir.GetPath(), stream.GetName())) + if err != nil { + return err + } + defer func() { + _ = dstFile.Close() + }() + err = utils.CopyWithCtx(ctx, dstFile, stream, stream.GetSize(), up) + return err +} + +var _ driver.Driver = (*SFTP)(nil) diff --git a/drivers/sftp/meta.go b/drivers/sftp/meta.go new file mode 100644 index 0000000000000000000000000000000000000000..9b1665679cdbf788cb62454e49d78f31a90abc38 --- /dev/null +++ b/drivers/sftp/meta.go @@ -0,0 +1,30 @@ +package sftp + +import ( + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/op" +) + +type Addition struct { + Address string `json:"address" required:"true"` + Username string `json:"username" required:"true"` + PrivateKey string `json:"private_key" type:"text"` + Password string `json:"password"` + Passphrase string `json:"passphrase"` + driver.RootPath + IgnoreSymlinkError bool `json:"ignore_symlink_error" default:"false" info:"Ignore symlink error"` +} + +var config = driver.Config{ + Name: "SFTP", + LocalSort: true, + OnlyLocal: true, + DefaultRoot: "/", + CheckStatus: true, +} + +func init() { + op.RegisterDriver(func() driver.Driver { + return &SFTP{} + }) +} diff --git a/drivers/sftp/types.go b/drivers/sftp/types.go new file mode 100644 index 0000000000000000000000000000000000000000..493e884c151a57c5650cad95978dc08d75d4aadc --- /dev/null +++ b/drivers/sftp/types.go @@ -0,0 +1,53 @@ +package sftp + +import ( + "os" + stdpath "path" + "strings" + + "github.com/alist-org/alist/v3/internal/model" + log "github.com/sirupsen/logrus" +) + +func (d *SFTP) fileToObj(f os.FileInfo, dir string) (model.Obj, error) { + symlink := f.Mode()&os.ModeSymlink != 0 + if !symlink { + return &model.Object{ + Name: f.Name(), + Size: f.Size(), + Modified: f.ModTime(), + IsFolder: f.IsDir(), + }, nil + } + path := stdpath.Join(dir, f.Name()) + // set target path + target, err := d.client.ReadLink(path) + if err != nil { + return nil, err + } + if !strings.HasPrefix(target, "/") { + target = stdpath.Join(dir, target) + } + _f, err := d.client.Stat(target) + if err != nil { + if d.IgnoreSymlinkError { + return &model.Object{ + Name: f.Name(), + Size: f.Size(), + Modified: f.ModTime(), + IsFolder: f.IsDir(), + }, nil + } + return nil, err + } + // set basic info + obj := &model.Object{ + Name: f.Name(), + Size: _f.Size(), + Modified: _f.ModTime(), + IsFolder: _f.IsDir(), + Path: target, + } + log.Debugf("[sftp] obj: %+v, is symlink: %v", obj, symlink) + return obj, nil +} diff --git a/drivers/sftp/util.go b/drivers/sftp/util.go new file mode 100644 index 0000000000000000000000000000000000000000..53f9c379e04195e07dfb4012700d7ab1c69f4e95 --- /dev/null +++ b/drivers/sftp/util.go @@ -0,0 +1,96 @@ +package sftp + +import ( + "path" + + "github.com/pkg/sftp" + log "github.com/sirupsen/logrus" + "golang.org/x/crypto/ssh" +) + +// do others that not defined in Driver interface + +func (d *SFTP) initClient() error { + var auth ssh.AuthMethod + if len(d.PrivateKey) > 0 { + var err error + var signer ssh.Signer + if len(d.Passphrase) > 0 { + signer, err = ssh.ParsePrivateKeyWithPassphrase([]byte(d.PrivateKey), []byte(d.Passphrase)) + } else { + signer, err = ssh.ParsePrivateKey([]byte(d.PrivateKey)) + } + if err != nil { + return err + } + auth = ssh.PublicKeys(signer) + } else { + auth = ssh.Password(d.Password) + } + config := 
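// The client config assembled below verifies nothing about the server: it
// uses ssh.InsecureIgnoreHostKey(), which is convenient for arbitrary
// user-configured hosts but leaves the connection open to MITM. A stricter
// sketch, assuming golang.org/x/crypto/ssh/knownhosts were imported and a
// known_hosts file were available (neither is part of this driver):
//
//	hostKeyCallback, err := knownhosts.New("/home/user/.ssh/known_hosts")
//	if err != nil {
//		return err
//	}
//	config := &ssh.ClientConfig{
//		User:            d.Username,
//		Auth:            []ssh.AuthMethod{auth},
//		HostKeyCallback: hostKeyCallback, // reject unknown or changed keys
//	}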
&ssh.ClientConfig{ + User: d.Username, + Auth: []ssh.AuthMethod{auth}, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + } + conn, err := ssh.Dial("tcp", d.Address, config) + if err != nil { + return err + } + d.client, err = sftp.NewClient(conn) + if err == nil { + d.clientConnectionError = nil + go func(d *SFTP) { + d.clientConnectionError = d.client.Wait() + }(d) + } + return err +} + +func (d *SFTP) clientReconnectOnConnectionError() error { + err := d.clientConnectionError + if err == nil { + return nil + } + log.Debugf("[sftp] discarding closed sftp connection: %v", err) + _ = d.client.Close() + err = d.initClient() + return err +} + +func (d *SFTP) remove(remotePath string) error { + f, err := d.client.Stat(remotePath) + if err != nil { + return nil + } + if f.IsDir() { + return d.removeDirectory(remotePath) + } else { + return d.removeFile(remotePath) + } +} + +func (d *SFTP) removeDirectory(remotePath string) error { + remoteFiles, err := d.client.ReadDir(remotePath) + if err != nil { + return err + } + for _, backupDir := range remoteFiles { + remoteFilePath := path.Join(remotePath, backupDir.Name()) + if backupDir.IsDir() { + err := d.removeDirectory(remoteFilePath) + if err != nil { + return err + } + } else { + err := d.removeFile(remoteFilePath) + if err != nil { + return err + } + } + } + return d.client.RemoveDirectory(remotePath) +} + +func (d *SFTP) removeFile(remotePath string) error { + return d.client.Remove(path.Join(remotePath)) +} diff --git a/drivers/smb/driver.go b/drivers/smb/driver.go new file mode 100644 index 0000000000000000000000000000000000000000..9632f24e0ebcdc89902d1ed44155f2db1013c190 --- /dev/null +++ b/drivers/smb/driver.go @@ -0,0 +1,200 @@ +package smb + +import ( + "context" + "errors" + "path/filepath" + "strings" + + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/pkg/utils" + + "github.com/hirochachacha/go-smb2" +) + +type SMB struct { + lastConnTime int64 + model.Storage + Addition + fs *smb2.Share +} + +func (d *SMB) Config() driver.Config { + return config +} + +func (d *SMB) GetAddition() driver.Additional { + return &d.Addition +} + +func (d *SMB) Init(ctx context.Context) error { + if strings.Index(d.Addition.Address, ":") < 0 { + d.Addition.Address = d.Addition.Address + ":445" + } + return d.initFS() +} + +func (d *SMB) Drop(ctx context.Context) error { + if d.fs != nil { + _ = d.fs.Umount() + } + return nil +} + +func (d *SMB) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) { + if err := d.checkConn(); err != nil { + return nil, err + } + fullPath := dir.GetPath() + rawFiles, err := d.fs.ReadDir(fullPath) + if err != nil { + d.cleanLastConnTime() + return nil, err + } + d.updateLastConnTime() + var files []model.Obj + for _, f := range rawFiles { + file := model.ObjThumb{ + Object: model.Object{ + Name: f.Name(), + Modified: f.ModTime(), + Size: f.Size(), + IsFolder: f.IsDir(), + Ctime: f.(*smb2.FileStat).CreationTime, + }, + } + files = append(files, &file) + } + return files, nil +} + +func (d *SMB) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) { + if err := d.checkConn(); err != nil { + return nil, err + } + fullPath := file.GetPath() + remoteFile, err := d.fs.Open(fullPath) + if err != nil { + d.cleanLastConnTime() + return nil, err + } + link := &model.Link{ + MFile: remoteFile, + } + d.updateLastConnTime() + return link, nil +} + +func (d *SMB) MakeDir(ctx context.Context, 
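// initClient above has no cheap liveness probe for the SSH transport, so it
// parks a goroutine on client.Wait(), which blocks until the connection
// dies and then records the error; clientReconnectOnConnectionError checks
// that field before every operation and re-dials lazily. The pattern in
// isolation (watchedClient is an illustrative name; like the driver, it
// tolerates the unsynchronized write from the goroutine):
type watchedClient struct {
	client  *sftp.Client
	waitErr error // non-nil once the underlying connection has closed
}

func watch(c *sftp.Client) *watchedClient {
	w := &watchedClient{client: c}
	go func() {
		w.waitErr = c.Wait() // blocks for the lifetime of the connection
	}()
	return w
}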
parentDir model.Obj, dirName string) error { + if err := d.checkConn(); err != nil { + return err + } + fullPath := filepath.Join(parentDir.GetPath(), dirName) + err := d.fs.MkdirAll(fullPath, 0700) + if err != nil { + d.cleanLastConnTime() + return err + } + d.updateLastConnTime() + return nil +} + +func (d *SMB) Move(ctx context.Context, srcObj, dstDir model.Obj) error { + if err := d.checkConn(); err != nil { + return err + } + srcPath := srcObj.GetPath() + dstPath := filepath.Join(dstDir.GetPath(), srcObj.GetName()) + err := d.fs.Rename(srcPath, dstPath) + if err != nil { + d.cleanLastConnTime() + return err + } + d.updateLastConnTime() + return nil +} + +func (d *SMB) Rename(ctx context.Context, srcObj model.Obj, newName string) error { + if err := d.checkConn(); err != nil { + return err + } + srcPath := srcObj.GetPath() + dstPath := filepath.Join(filepath.Dir(srcPath), newName) + err := d.fs.Rename(srcPath, dstPath) + if err != nil { + d.cleanLastConnTime() + return err + } + d.updateLastConnTime() + return nil +} + +func (d *SMB) Copy(ctx context.Context, srcObj, dstDir model.Obj) error { + if err := d.checkConn(); err != nil { + return err + } + srcPath := srcObj.GetPath() + dstPath := filepath.Join(dstDir.GetPath(), srcObj.GetName()) + var err error + if srcObj.IsDir() { + err = d.CopyDir(srcPath, dstPath) + } else { + err = d.CopyFile(srcPath, dstPath) + } + if err != nil { + d.cleanLastConnTime() + return err + } + d.updateLastConnTime() + return nil +} + +func (d *SMB) Remove(ctx context.Context, obj model.Obj) error { + if err := d.checkConn(); err != nil { + return err + } + var err error + fullPath := obj.GetPath() + if obj.IsDir() { + err = d.fs.RemoveAll(fullPath) + } else { + err = d.fs.Remove(fullPath) + } + if err != nil { + d.cleanLastConnTime() + return err + } + d.updateLastConnTime() + return nil +} + +func (d *SMB) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error { + if err := d.checkConn(); err != nil { + return err + } + fullPath := filepath.Join(dstDir.GetPath(), stream.GetName()) + out, err := d.fs.Create(fullPath) + if err != nil { + d.cleanLastConnTime() + return err + } + d.updateLastConnTime() + defer func() { + _ = out.Close() + if errors.Is(err, context.Canceled) { + _ = d.fs.Remove(fullPath) + } + }() + err = utils.CopyWithCtx(ctx, out, stream, stream.GetSize(), up) + if err != nil { + return err + } + return nil +} + +//func (d *SMB) Other(ctx context.Context, args model.OtherArgs) (interface{}, error) { +// return nil, errs.NotSupport +//} + +var _ driver.Driver = (*SMB)(nil) diff --git a/drivers/smb/meta.go b/drivers/smb/meta.go new file mode 100644 index 0000000000000000000000000000000000000000..3386db2fb42d1716b68da13212d0abfd538256a7 --- /dev/null +++ b/drivers/smb/meta.go @@ -0,0 +1,28 @@ +package smb + +import ( + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/op" +) + +type Addition struct { + driver.RootPath + Address string `json:"address" required:"true"` + Username string `json:"username" required:"true"` + Password string `json:"password"` + ShareName string `json:"share_name" required:"true"` +} + +var config = driver.Config{ + Name: "SMB", + LocalSort: true, + OnlyLocal: true, + DefaultRoot: ".", + NoCache: true, +} + +func init() { + op.RegisterDriver(func() driver.Driver { + return &SMB{} + }) +} diff --git a/drivers/smb/types.go b/drivers/smb/types.go new file mode 100644 index 
0000000000000000000000000000000000000000..161798ade6d04b116ce9659bf449a68ee1f90b2d --- /dev/null +++ b/drivers/smb/types.go @@ -0,0 +1 @@ +package smb diff --git a/drivers/smb/util.go b/drivers/smb/util.go new file mode 100644 index 0000000000000000000000000000000000000000..d9fbf6c5a5a32300ba77b0c7fd2ee21cad6ce6af --- /dev/null +++ b/drivers/smb/util.go @@ -0,0 +1,138 @@ +package smb + +import ( + "github.com/alist-org/alist/v3/pkg/utils" + "io/fs" + "net" + "os" + "path/filepath" + "sync/atomic" + "time" + + "github.com/hirochachacha/go-smb2" +) + +func (d *SMB) updateLastConnTime() { + atomic.StoreInt64(&d.lastConnTime, time.Now().Unix()) +} + +func (d *SMB) cleanLastConnTime() { + atomic.StoreInt64(&d.lastConnTime, 0) +} + +func (d *SMB) getLastConnTime() time.Time { + return time.Unix(atomic.LoadInt64(&d.lastConnTime), 0) +} + +func (d *SMB) initFS() error { + conn, err := net.Dial("tcp", d.Address) + if err != nil { + return err + } + dialer := &smb2.Dialer{ + Initiator: &smb2.NTLMInitiator{ + User: d.Username, + Password: d.Password, + }, + } + s, err := dialer.Dial(conn) + if err != nil { + return err + } + d.fs, err = s.Mount(d.ShareName) + if err != nil { + return err + } + d.updateLastConnTime() + return err +} + +func (d *SMB) checkConn() error { + if time.Since(d.getLastConnTime()) < 5*time.Minute { + return nil + } + if d.fs != nil { + _ = d.fs.Umount() + } + return d.initFS() +} + +// CopyFile File copies a single file from src to dst +func (d *SMB) CopyFile(src, dst string) error { + var err error + var srcfd *smb2.File + var dstfd *smb2.File + var srcinfo fs.FileInfo + + if srcfd, err = d.fs.Open(src); err != nil { + return err + } + defer srcfd.Close() + + if dstfd, err = d.CreateNestedFile(dst); err != nil { + return err + } + defer dstfd.Close() + + if _, err = utils.CopyWithBuffer(dstfd, srcfd); err != nil { + return err + } + if srcinfo, err = d.fs.Stat(src); err != nil { + return err + } + return d.fs.Chmod(dst, srcinfo.Mode()) +} + +// CopyDir Dir copies a whole directory recursively +func (d *SMB) CopyDir(src string, dst string) error { + var err error + var fds []fs.FileInfo + var srcinfo fs.FileInfo + + if srcinfo, err = d.fs.Stat(src); err != nil { + return err + } + if err = d.fs.MkdirAll(dst, srcinfo.Mode()); err != nil { + return err + } + if fds, err = d.fs.ReadDir(src); err != nil { + return err + } + for _, fd := range fds { + srcfp := filepath.Join(src, fd.Name()) + dstfp := filepath.Join(dst, fd.Name()) + + if fd.IsDir() { + if err = d.CopyDir(srcfp, dstfp); err != nil { + return err + } + } else { + if err = d.CopyFile(srcfp, dstfp); err != nil { + return err + } + } + } + return nil +} + +// Exists determine whether the file exists +func (d *SMB) Exists(name string) bool { + if _, err := d.fs.Stat(name); err != nil { + if os.IsNotExist(err) { + return false + } + } + return true +} + +// CreateNestedFile create nested file +func (d *SMB) CreateNestedFile(path string) (*smb2.File, error) { + basePath := filepath.Dir(path) + if !d.Exists(basePath) { + err := d.fs.MkdirAll(basePath, 0700) + if err != nil { + return nil, err + } + } + return d.fs.Create(path) +} diff --git a/drivers/teambition/driver.go b/drivers/teambition/driver.go new file mode 100644 index 0000000000000000000000000000000000000000..c75d2ac00b6f07ec516d6e39fed3677da78dfc1f --- /dev/null +++ b/drivers/teambition/driver.go @@ -0,0 +1,163 @@ +package teambition + +import ( + "context" + "errors" + "github.com/alist-org/alist/v3/pkg/utils" + "net/http" + + 
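// The SMB helpers above keep a single share mounted and only re-dial after
// the connection has sat unused past a freshness window: every successful
// call stamps lastConnTime, failures zero it, and checkConn re-mounts when
// the stamp is older than five minutes. That gate reduces to a one-liner
// (stale is an illustrative name; assumes time and sync/atomic):
func stale(lastUnix *int64, window time.Duration) bool {
	return time.Since(time.Unix(atomic.LoadInt64(lastUnix), 0)) >= window
}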
"github.com/alist-org/alist/v3/drivers/base" + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/model" + "github.com/go-resty/resty/v2" +) + +type Teambition struct { + model.Storage + Addition +} + +func (d *Teambition) Config() driver.Config { + return config +} + +func (d *Teambition) GetAddition() driver.Additional { + return &d.Addition +} + +func (d *Teambition) Init(ctx context.Context) error { + _, err := d.request("/api/v2/roles", http.MethodGet, nil, nil) + return err +} + +func (d *Teambition) Drop(ctx context.Context) error { + return nil +} + +func (d *Teambition) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) { + return d.getFiles(dir.GetID()) +} + +func (d *Teambition) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) { + if u, ok := file.(model.URL); ok { + url := u.URL() + res, _ := base.NoRedirectClient.R().Get(url) + if res.StatusCode() == 302 { + url = res.Header().Get("location") + } + return &model.Link{URL: url}, nil + } + return nil, errors.New("can't convert obj to URL") +} + +func (d *Teambition) MakeDir(ctx context.Context, parentDir model.Obj, dirName string) error { + data := base.Json{ + "objectType": "collection", + "_projectId": d.ProjectID, + "_creatorId": "", + "created": "", + "updated": "", + "title": dirName, + "color": "blue", + "description": "", + "workCount": 0, + "collectionType": "", + "recentWorks": []interface{}{}, + "_parentId": parentDir.GetID(), + "subCount": nil, + } + _, err := d.request("/api/collections", http.MethodPost, func(req *resty.Request) { + req.SetBody(data) + }, nil) + return err +} + +func (d *Teambition) Move(ctx context.Context, srcObj, dstDir model.Obj) error { + pre := "/api/works/" + if srcObj.IsDir() { + pre = "/api/collections/" + } + _, err := d.request(pre+srcObj.GetID()+"/move", http.MethodPut, func(req *resty.Request) { + req.SetBody(base.Json{ + "_parentId": dstDir.GetID(), + }) + }, nil) + return err +} + +func (d *Teambition) Rename(ctx context.Context, srcObj model.Obj, newName string) error { + pre := "/api/works/" + data := base.Json{ + "fileName": newName, + } + if srcObj.IsDir() { + pre = "/api/collections/" + data = base.Json{ + "title": newName, + } + } + _, err := d.request(pre+srcObj.GetID(), http.MethodPut, func(req *resty.Request) { + req.SetBody(data) + }, nil) + return err +} + +func (d *Teambition) Copy(ctx context.Context, srcObj, dstDir model.Obj) error { + pre := "/api/works/" + if srcObj.IsDir() { + pre = "/api/collections/" + } + _, err := d.request(pre+srcObj.GetID()+"/fork", http.MethodPut, func(req *resty.Request) { + req.SetBody(base.Json{ + "_parentId": dstDir.GetID(), + }) + }, nil) + return err +} + +func (d *Teambition) Remove(ctx context.Context, obj model.Obj) error { + pre := "/api/works/" + if obj.IsDir() { + pre = "/api/collections/" + } + _, err := d.request(pre+obj.GetID()+"/archive", http.MethodPost, nil, nil) + return err +} + +func (d *Teambition) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error { + if d.UseS3UploadMethod { + return d.newUpload(ctx, dstDir, stream, up) + } + var ( + token string + err error + ) + if d.isInternational() { + res, err := d.request("/projects", http.MethodGet, nil, nil) + if err != nil { + return err + } + token = getBetweenStr(string(res), "strikerAuth":"", "","phoneForLogin") + } else { + res, err := d.request("/api/v2/users/me", http.MethodGet, nil, nil) + if err != nil { + return 
err + } + token = utils.Json.Get(res, "strikerAuth").ToString() + } + var newFile *FileUpload + if stream.GetSize() <= 20971520 { + // post upload + newFile, err = d.upload(ctx, stream, token) + } else { + // chunk upload + //err = base.ErrNotImplement + newFile, err = d.chunkUpload(ctx, stream, token, up) + } + if err != nil { + return err + } + return d.finishUpload(newFile, dstDir.GetID()) +} + +var _ driver.Driver = (*Teambition)(nil) diff --git a/drivers/teambition/help.go b/drivers/teambition/help.go new file mode 100644 index 0000000000000000000000000000000000000000..8581c3e827cad8dafd9be658cb33429403513dff --- /dev/null +++ b/drivers/teambition/help.go @@ -0,0 +1,18 @@ +package teambition + +import "strings" + +func getBetweenStr(str, start, end string) string { + n := strings.Index(str, start) + if n == -1 { + return "" + } + n = n + len(start) + str = string([]byte(str)[n:]) + m := strings.Index(str, end) + if m == -1 { + return "" + } + str = string([]byte(str)[:m]) + return str +} diff --git a/drivers/teambition/meta.go b/drivers/teambition/meta.go new file mode 100644 index 0000000000000000000000000000000000000000..45a6a47286cd6110436fef9be40a2c4258752971 --- /dev/null +++ b/drivers/teambition/meta.go @@ -0,0 +1,26 @@ +package teambition + +import ( + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/op" +) + +type Addition struct { + Region string `json:"region" type:"select" options:"china,international" required:"true"` + Cookie string `json:"cookie" required:"true"` + ProjectID string `json:"project_id" required:"true"` + driver.RootID + OrderBy string `json:"order_by" type:"select" options:"fileName,fileSize,updated,created" default:"fileName"` + OrderDirection string `json:"order_direction" type:"select" options:"Asc,Desc" default:"Asc"` + UseS3UploadMethod bool `json:"use_s3_upload_method" default:"true"` +} + +var config = driver.Config{ + Name: "Teambition", +} + +func init() { + op.RegisterDriver(func() driver.Driver { + return &Teambition{} + }) +} diff --git a/drivers/teambition/types.go b/drivers/teambition/types.go new file mode 100644 index 0000000000000000000000000000000000000000..eada9a6d0724a6491af8fc57f93c339241869cd5 --- /dev/null +++ b/drivers/teambition/types.go @@ -0,0 +1,89 @@ +package teambition + +import "time" + +type ErrResp struct { + Name string `json:"name"` + Message string `json:"message"` +} + +type Collection struct { + ID string `json:"_id"` + Title string `json:"title"` + Updated time.Time `json:"updated"` +} + +type Work struct { + ID string `json:"_id"` + FileName string `json:"fileName"` + FileSize int64 `json:"fileSize"` + FileKey string `json:"fileKey"` + FileCategory string `json:"fileCategory"` + DownloadURL string `json:"downloadUrl"` + ThumbnailURL string `json:"thumbnailUrl"` + Thumbnail string `json:"thumbnail"` + Updated time.Time `json:"updated"` + PreviewURL string `json:"previewUrl"` +} + +type FileUpload struct { + FileKey string `json:"fileKey"` + FileName string `json:"fileName"` + FileType string `json:"fileType"` + FileSize int `json:"fileSize"` + FileCategory string `json:"fileCategory"` + ImageWidth int `json:"imageWidth"` + ImageHeight int `json:"imageHeight"` + InvolveMembers []interface{} `json:"involveMembers"` + Source string `json:"source"` + Visible string `json:"visible"` + ParentId string `json:"_parentId"` +} + +type ChunkUpload struct { + FileUpload + Storage string `json:"storage"` + MimeType string `json:"mimeType"` + Chunks int `json:"chunks"` + ChunkSize int 
`json:"chunkSize"` + Created time.Time `json:"created"` + FileMD5 string `json:"fileMD5"` + LastUpdated time.Time `json:"lastUpdated"` + UploadedChunks []interface{} `json:"uploadedChunks"` + Token struct { + AppID string `json:"AppID"` + OrganizationID string `json:"OrganizationID"` + UserID string `json:"UserID"` + Exp time.Time `json:"Exp"` + Storage string `json:"Storage"` + Resource string `json:"Resource"` + Speed int `json:"Speed"` + } `json:"token"` + DownloadUrl string `json:"downloadUrl"` + ThumbnailUrl string `json:"thumbnailUrl"` + PreviewUrl string `json:"previewUrl"` + ImmPreviewUrl string `json:"immPreviewUrl"` + PreviewExt string `json:"previewExt"` + LastUploadTime interface{} `json:"lastUploadTime"` +} + +type UploadToken struct { + Sdk struct { + Endpoint string `json:"endpoint"` + Region string `json:"region"` + S3ForcePathStyle bool `json:"s3ForcePathStyle"` + Credentials struct { + AccessKeyId string `json:"accessKeyId"` + SecretAccessKey string `json:"secretAccessKey"` + SessionToken string `json:"sessionToken"` + } `json:"credentials"` + } `json:"sdk"` + Upload struct { + Bucket string `json:"Bucket"` + Key string `json:"Key"` + ContentDisposition string `json:"ContentDisposition"` + ContentType string `json:"ContentType"` + } `json:"upload"` + Token string `json:"token"` + DownloadUrl string `json:"downloadUrl"` +} diff --git a/drivers/teambition/util.go b/drivers/teambition/util.go new file mode 100644 index 0000000000000000000000000000000000000000..181cc58f64ca6f06f6d45ea6baab01610698380e --- /dev/null +++ b/drivers/teambition/util.go @@ -0,0 +1,272 @@ +package teambition + +import ( + "context" + "errors" + "fmt" + "io" + "net/http" + "strconv" + "time" + + "github.com/alist-org/alist/v3/drivers/base" + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/s3/s3manager" + "github.com/go-resty/resty/v2" + log "github.com/sirupsen/logrus" +) + +// do others that not defined in Driver interface + +func (d *Teambition) isInternational() bool { + return d.Region == "international" +} + +func (d *Teambition) request(pathname string, method string, callback base.ReqCallback, resp interface{}) ([]byte, error) { + url := "https://www.teambition.com" + pathname + if d.isInternational() { + url = "https://us.teambition.com" + pathname + } + req := base.RestyClient.R() + req.SetHeader("Cookie", d.Cookie) + if callback != nil { + callback(req) + } + if resp != nil { + req.SetResult(resp) + } + var e ErrResp + req.SetError(&e) + res, err := req.Execute(method, url) + if err != nil { + return nil, err + } + if e.Name != "" { + return nil, errors.New(e.Message) + } + return res.Body(), nil +} + +func (d *Teambition) getFiles(parentId string) ([]model.Obj, error) { + files := make([]model.Obj, 0) + page := 1 + for { + var collections []Collection + _, err := d.request("/api/collections", http.MethodGet, func(req *resty.Request) { + req.SetQueryParams(map[string]string{ + "_parentId": parentId, + "_projectId": d.ProjectID, + "order": d.OrderBy + d.OrderDirection, + "count": "50", + "page": strconv.Itoa(page), + }) + }, &collections) + if err != nil { + return nil, err + } + if len(collections) == 0 { + break + } + page++ + for _, collection := range collections { + if collection.Title == "" { + continue + } + files = append(files, 
&model.Object{ + ID: collection.ID, + Name: collection.Title, + IsFolder: true, + Modified: collection.Updated, + }) + } + } + page = 1 + for { + var works []Work + _, err := d.request("/api/works", http.MethodGet, func(req *resty.Request) { + req.SetQueryParams(map[string]string{ + "_parentId": parentId, + "_projectId": d.ProjectID, + "order": d.OrderBy + d.OrderDirection, + "count": "50", + "page": strconv.Itoa(page), + }) + }, &works) + if err != nil { + return nil, err + } + if len(works) == 0 { + break + } + page++ + for _, work := range works { + files = append(files, &model.ObjThumbURL{ + Object: model.Object{ + ID: work.ID, + Name: work.FileName, + Size: work.FileSize, + Modified: work.Updated, + }, + Thumbnail: model.Thumbnail{Thumbnail: work.Thumbnail}, + Url: model.Url{Url: work.DownloadURL}, + }) + } + } + return files, nil +} + +func (d *Teambition) upload(ctx context.Context, file model.FileStreamer, token string) (*FileUpload, error) { + prefix := "tcs" + if d.isInternational() { + prefix = "us-tcs" + } + var newFile FileUpload + res, err := base.RestyClient.R(). + SetContext(ctx). + SetResult(&newFile).SetHeader("Authorization", token). + SetMultipartFormData(map[string]string{ + "name": file.GetName(), + "type": file.GetMimetype(), + "size": strconv.FormatInt(file.GetSize(), 10), + "lastModifiedDate": time.Now().Format("Mon Jan 02 2006 15:04:05 GMT+0800 (中国标准时间)"), + }).SetMultipartField("file", file.GetName(), file.GetMimetype(), file). + Post(fmt.Sprintf("https://%s.teambition.net/upload", prefix)) + if err != nil { + return nil, err + } + log.Debugf("[teambition] upload response: %s", res.String()) + return &newFile, nil +} + +func (d *Teambition) chunkUpload(ctx context.Context, file model.FileStreamer, token string, up driver.UpdateProgress) (*FileUpload, error) { + prefix := "tcs" + referer := "https://www.teambition.com/" + if d.isInternational() { + prefix = "us-tcs" + referer = "https://us.teambition.com/" + } + var newChunk ChunkUpload + _, err := base.RestyClient.R().SetResult(&newChunk).SetHeader("Authorization", token). + SetBody(base.Json{ + "fileName": file.GetName(), + "fileSize": file.GetSize(), + "lastUpdated": time.Now(), + }).Post(fmt.Sprintf("https://%s.teambition.net/upload/chunk", prefix)) + if err != nil { + return nil, err + } + for i := 0; i < newChunk.Chunks; i++ { + if utils.IsCanceled(ctx) { + return nil, ctx.Err() + } + chunkSize := newChunk.ChunkSize + if i == newChunk.Chunks-1 { + chunkSize = int(file.GetSize()) - i*chunkSize + } + log.Debugf("%d : %d", i, chunkSize) + chunkData := make([]byte, chunkSize) + _, err = io.ReadFull(file, chunkData) + if err != nil { + return nil, err + } + u := fmt.Sprintf("https://%s.teambition.net/upload/chunk/%s?chunk=%d&chunks=%d", + prefix, newChunk.FileKey, i+1, newChunk.Chunks) + log.Debugf("url: %s", u) + _, err := base.RestyClient.R(). + SetContext(ctx). 
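// In the chunk loop above every part is newChunk.ChunkSize bytes except the
// last, which carries the remainder: for i == Chunks-1 the size becomes
// fileSize - i*chunkSize. For example, a 10 MiB file with 4 MiB chunks
// uploads parts of 4, 4 and 2 MiB. The arithmetic in isolation (chunkSizes
// is an illustrative helper; assumes the math package):
func chunkSizes(total int64, chunk int) []int {
	n := int(math.Ceil(float64(total) / float64(chunk)))
	sizes := make([]int, n)
	for i := range sizes {
		sizes[i] = chunk
	}
	if n > 0 {
		sizes[n-1] = int(total) - (n-1)*chunk // the remainder goes last
	}
	return sizes
}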
+ SetHeaders(map[string]string{ + "Authorization": token, + "Content-Type": "application/octet-stream", + "Referer": referer, + }).SetBody(chunkData).Post(u) + if err != nil { + return nil, err + } + if err != nil { + return nil, err + } + up(float64(i) * 100 / float64(newChunk.Chunks)) + } + _, err = base.RestyClient.R().SetHeader("Authorization", token).Post( + fmt.Sprintf("https://%s.teambition.net/upload/chunk/%s", + prefix, newChunk.FileKey)) + if err != nil { + return nil, err + } + return &newChunk.FileUpload, nil +} + +func (d *Teambition) finishUpload(file *FileUpload, parentId string) error { + file.InvolveMembers = []interface{}{} + file.Visible = "members" + file.ParentId = parentId + _, err := d.request("/api/works", http.MethodPost, func(req *resty.Request) { + req.SetBody(base.Json{ + "works": []FileUpload{*file}, + "_parentId": parentId, + }) + }, nil) + return err +} + +func (d *Teambition) newUpload(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error { + var uploadToken UploadToken + _, err := d.request("/api/awos/upload-token", http.MethodPost, func(req *resty.Request) { + req.SetBody(base.Json{ + "category": "work", + "fileName": stream.GetName(), + "fileSize": stream.GetSize(), + "fileType": stream.GetMimetype(), + "payload": base.Json{ + "involveMembers": []struct{}{}, + "visible": "members", + }, + "scope": "project:" + d.ProjectID, + }) + }, &uploadToken) + if err != nil { + return err + } + cfg := &aws.Config{ + Credentials: credentials.NewStaticCredentials( + uploadToken.Sdk.Credentials.AccessKeyId, uploadToken.Sdk.Credentials.SecretAccessKey, uploadToken.Sdk.Credentials.SessionToken), + Region: &uploadToken.Sdk.Region, + Endpoint: &uploadToken.Sdk.Endpoint, + S3ForcePathStyle: &uploadToken.Sdk.S3ForcePathStyle, + } + ss, err := session.NewSession(cfg) + if err != nil { + return err + } + uploader := s3manager.NewUploader(ss) + if stream.GetSize() > s3manager.MaxUploadParts*s3manager.DefaultUploadPartSize { + uploader.PartSize = stream.GetSize() / (s3manager.MaxUploadParts - 1) + } + input := &s3manager.UploadInput{ + Bucket: &uploadToken.Upload.Bucket, + Key: &uploadToken.Upload.Key, + ContentDisposition: &uploadToken.Upload.ContentDisposition, + ContentType: &uploadToken.Upload.ContentType, + Body: stream, + } + _, err = uploader.UploadWithContext(ctx, input) + if err != nil { + return err + } + // finish upload + _, err = d.request("/api/works", http.MethodPost, func(req *resty.Request) { + req.SetBody(base.Json{ + "fileTokens": []string{uploadToken.Token}, + "involveMembers": []struct{}{}, + "visible": "members", + "works": []struct{}{}, + "_parentId": dstDir.GetID(), + }) + }, nil) + return err +} diff --git a/drivers/template/driver.go b/drivers/template/driver.go new file mode 100644 index 0000000000000000000000000000000000000000..439f57f35f9913c68a855a3d67d6d11ca915068c --- /dev/null +++ b/drivers/template/driver.go @@ -0,0 +1,78 @@ +package template + +import ( + "context" + + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/errs" + "github.com/alist-org/alist/v3/internal/model" +) + +type Template struct { + model.Storage + Addition +} + +func (d *Template) Config() driver.Config { + return config +} + +func (d *Template) GetAddition() driver.Additional { + return &d.Addition +} + +func (d *Template) Init(ctx context.Context) error { + // TODO login / refresh token + //op.MustSaveDriverStorage(d) + return nil +} + +func (d *Template) Drop(ctx context.Context) error { + 
return nil +} + +func (d *Template) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) { + // TODO return the files list, required + return nil, errs.NotImplement +} + +func (d *Template) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) { + // TODO return link of file, required + return nil, errs.NotImplement +} + +func (d *Template) MakeDir(ctx context.Context, parentDir model.Obj, dirName string) (model.Obj, error) { + // TODO create folder, optional + return nil, errs.NotImplement +} + +func (d *Template) Move(ctx context.Context, srcObj, dstDir model.Obj) (model.Obj, error) { + // TODO move obj, optional + return nil, errs.NotImplement +} + +func (d *Template) Rename(ctx context.Context, srcObj model.Obj, newName string) (model.Obj, error) { + // TODO rename obj, optional + return nil, errs.NotImplement +} + +func (d *Template) Copy(ctx context.Context, srcObj, dstDir model.Obj) (model.Obj, error) { + // TODO copy obj, optional + return nil, errs.NotImplement +} + +func (d *Template) Remove(ctx context.Context, obj model.Obj) error { + // TODO remove obj, optional + return errs.NotImplement +} + +func (d *Template) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) (model.Obj, error) { + // TODO upload file, optional + return nil, errs.NotImplement +} + +//func (d *Template) Other(ctx context.Context, args model.OtherArgs) (interface{}, error) { +// return nil, errs.NotSupport +//} + +var _ driver.Driver = (*Template)(nil) diff --git a/drivers/template/meta.go b/drivers/template/meta.go new file mode 100644 index 0000000000000000000000000000000000000000..45aeb85ac58e5072d86153d67035b486263601b6 --- /dev/null +++ b/drivers/template/meta.go @@ -0,0 +1,34 @@ +package template + +import ( + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/op" +) + +type Addition struct { + // Usually one of two + driver.RootPath + driver.RootID + // define other + Field string `json:"field" type:"select" required:"true" options:"a,b,c" default:"a"` +} + +var config = driver.Config{ + Name: "Template", + LocalSort: false, + OnlyLocal: false, + OnlyProxy: false, + NoCache: false, + NoUpload: false, + NeedMs: false, + DefaultRoot: "root, / or other", + CheckStatus: false, + Alert: "", + NoOverwriteUpload: false, +} + +func init() { + op.RegisterDriver(func() driver.Driver { + return &Template{} + }) +} diff --git a/drivers/template/types.go b/drivers/template/types.go new file mode 100644 index 0000000000000000000000000000000000000000..38cdfe4490d1752f31bad3921f011a0bfefdef68 --- /dev/null +++ b/drivers/template/types.go @@ -0,0 +1 @@ +package template diff --git a/drivers/template/util.go b/drivers/template/util.go new file mode 100644 index 0000000000000000000000000000000000000000..9d967bdfd11f5667a92e414c2f1c5d36309858c1 --- /dev/null +++ b/drivers/template/util.go @@ -0,0 +1,3 @@ +package template + +// do others that not defined in Driver interface diff --git a/drivers/terabox/driver.go b/drivers/terabox/driver.go new file mode 100644 index 0000000000000000000000000000000000000000..362de69e0a0aad30b335662be1a6c421c6cf2a62 --- /dev/null +++ b/drivers/terabox/driver.go @@ -0,0 +1,273 @@ +package terabox + +import ( + "bytes" + "context" + "crypto/md5" + "encoding/hex" + "fmt" + "io" + "math" + stdpath "path" + "strconv" + + "github.com/alist-org/alist/v3/drivers/base" + "github.com/alist-org/alist/v3/pkg/utils" + log "github.com/sirupsen/logrus" + + 
"github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/model" +) + +type Terabox struct { + model.Storage + Addition + JsToken string + url_domain_prefix string + base_url string +} + +func (d *Terabox) Config() driver.Config { + return config +} + +func (d *Terabox) GetAddition() driver.Additional { + return &d.Addition +} + +func (d *Terabox) Init(ctx context.Context) error { + var resp CheckLoginResp + d.base_url = "https://www.terabox.com" + d.url_domain_prefix = "jp" + _, err := d.get("/api/check/login", nil, &resp) + if err != nil { + return err + } + if resp.Errno != 0 { + if resp.Errno == 9000 { + return fmt.Errorf("terabox is not yet available in this area") + } + return fmt.Errorf("failed to check login status according to cookie") + } + return err +} + +func (d *Terabox) Drop(ctx context.Context) error { + return nil +} + +func (d *Terabox) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) { + files, err := d.getFiles(dir.GetPath()) + if err != nil { + return nil, err + } + return utils.SliceConvert(files, func(src File) (model.Obj, error) { + return fileToObj(src), nil + }) +} + +func (d *Terabox) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) { + if d.DownloadAPI == "crack" { + return d.linkCrack(file, args) + } + return d.linkOfficial(file, args) +} + +func (d *Terabox) MakeDir(ctx context.Context, parentDir model.Obj, dirName string) error { + params := map[string]string{ + "a": "commit", + } + data := map[string]string{ + "path": stdpath.Join(parentDir.GetPath(), dirName), + "isdir": "1", + "block_list": "[]", + } + res, err := d.post_form("/api/create", params, data, nil) + log.Debugln(string(res)) + return err +} + +func (d *Terabox) Move(ctx context.Context, srcObj, dstDir model.Obj) error { + data := []base.Json{ + { + "path": srcObj.GetPath(), + "dest": dstDir.GetPath(), + "newname": srcObj.GetName(), + }, + } + _, err := d.manage("move", data) + return err +} + +func (d *Terabox) Rename(ctx context.Context, srcObj model.Obj, newName string) error { + data := []base.Json{ + { + "path": srcObj.GetPath(), + "newname": newName, + }, + } + _, err := d.manage("rename", data) + return err +} + +func (d *Terabox) Copy(ctx context.Context, srcObj, dstDir model.Obj) error { + data := []base.Json{ + { + "path": srcObj.GetPath(), + "dest": dstDir.GetPath(), + "newname": srcObj.GetName(), + }, + } + _, err := d.manage("copy", data) + return err +} + +func (d *Terabox) Remove(ctx context.Context, obj model.Obj) error { + data := []string{obj.GetPath()} + _, err := d.manage("delete", data) + return err +} + +func (d *Terabox) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error { + resp, err := base.RestyClient.R(). + SetContext(ctx). 
+ Get("https://" + d.url_domain_prefix + "-data.terabox.com/rest/2.0/pcs/file?method=locateupload") + if err != nil { + return err + } + var locateupload_resp LocateUploadResp + err = utils.Json.Unmarshal(resp.Body(), &locateupload_resp) + if err != nil { + log.Debugln(resp) + return err + } + log.Debugln(locateupload_resp) + + // precreate file + rawPath := stdpath.Join(dstDir.GetPath(), stream.GetName()) + path := encodeURIComponent(rawPath) + + var precreateBlockListStr string + if stream.GetSize() > initialChunkSize { + precreateBlockListStr = `["5910a591dd8fc18c32a8f3df4fdc1761","a5fc157d78e6ad1c7e114b056c92821e"]` + } else { + precreateBlockListStr = `["5910a591dd8fc18c32a8f3df4fdc1761"]` + } + + data := map[string]string{ + "path": rawPath, + "autoinit": "1", + "target_path": dstDir.GetPath(), + "block_list": precreateBlockListStr, + "local_mtime": strconv.FormatInt(stream.ModTime().Unix(), 10), + "file_limit_switch_v34": "true", + } + var precreateResp PrecreateResp + log.Debugln(data) + res, err := d.post_form("/api/precreate", nil, data, &precreateResp) + if err != nil { + return err + } + log.Debugf("%+v", precreateResp) + if precreateResp.Errno != 0 { + log.Debugln(string(res)) + return fmt.Errorf("[terabox] failed to precreate file, errno: %d", precreateResp.Errno) + } + if precreateResp.ReturnType == 2 { + return nil + } + + // upload chunks + tempFile, err := stream.CacheFullInTempFile() + if err != nil { + return err + } + + params := map[string]string{ + "method": "upload", + "path": path, + "uploadid": precreateResp.Uploadid, + "app_id": "250528", + "web": "1", + "channel": "dubox", + "clienttype": "0", + } + + streamSize := stream.GetSize() + chunkSize := calculateChunkSize(streamSize) + chunkByteData := make([]byte, chunkSize) + count := int(math.Ceil(float64(streamSize) / float64(chunkSize))) + left := streamSize + uploadBlockList := make([]string, 0, count) + h := md5.New() + for partseq := 0; partseq < count; partseq++ { + if utils.IsCanceled(ctx) { + return ctx.Err() + } + byteSize := chunkSize + var byteData []byte + if left >= chunkSize { + byteData = chunkByteData + } else { + byteSize = left + byteData = make([]byte, byteSize) + } + left -= byteSize + _, err = io.ReadFull(tempFile, byteData) + if err != nil { + return err + } + + // calculate md5 + h.Write(byteData) + uploadBlockList = append(uploadBlockList, hex.EncodeToString(h.Sum(nil))) + h.Reset() + + u := "https://" + locateupload_resp.Host + "/rest/2.0/pcs/superfile2" + params["partseq"] = strconv.Itoa(partseq) + res, err := base.RestyClient.R(). + SetContext(ctx). + SetQueryParams(params). + SetFileReader("file", stream.GetName(), bytes.NewReader(byteData)). + SetHeader("Cookie", d.Cookie). 
+ Post(u) + if err != nil { + return err + } + log.Debugln(res.String()) + if count > 0 { + up(float64(partseq) * 100 / float64(count)) + } + } + + // create file + params = map[string]string{ + "isdir": "0", + "rtype": "1", + } + + uploadBlockListStr, err := utils.Json.MarshalToString(uploadBlockList) + if err != nil { + return err + } + data = map[string]string{ + "path": rawPath, + "size": strconv.FormatInt(stream.GetSize(), 10), + "uploadid": precreateResp.Uploadid, + "target_path": dstDir.GetPath(), + "block_list": uploadBlockListStr, + "local_mtime": strconv.FormatInt(stream.ModTime().Unix(), 10), + } + var createResp CreateResp + res, err = d.post_form("/api/create", params, data, &createResp) + log.Debugln(string(res)) + if err != nil { + return err + } + if createResp.Errno != 0 { + return fmt.Errorf("[terabox] failed to create file, errno: %d", createResp.Errno) + } + return nil +} + +var _ driver.Driver = (*Terabox)(nil) diff --git a/drivers/terabox/meta.go b/drivers/terabox/meta.go new file mode 100644 index 0000000000000000000000000000000000000000..63ae585647179402456f8b713b4fa242ce11fc78 --- /dev/null +++ b/drivers/terabox/meta.go @@ -0,0 +1,26 @@ +package terabox + +import ( + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/op" +) + +type Addition struct { + driver.RootPath + Cookie string `json:"cookie" required:"true"` + //JsToken string `json:"js_token" type:"string" required:"true"` + DownloadAPI string `json:"download_api" type:"select" options:"official,crack" default:"official"` + OrderBy string `json:"order_by" type:"select" options:"name,time,size" default:"name"` + OrderDirection string `json:"order_direction" type:"select" options:"asc,desc" default:"asc"` +} + +var config = driver.Config{ + Name: "Terabox", + DefaultRoot: "/", +} + +func init() { + op.RegisterDriver(func() driver.Driver { + return &Terabox{} + }) +} diff --git a/drivers/terabox/types.go b/drivers/terabox/types.go new file mode 100644 index 0000000000000000000000000000000000000000..f4d50ddef374f4f1bc9e5571f62a67059c063aa2 --- /dev/null +++ b/drivers/terabox/types.go @@ -0,0 +1,105 @@ +package terabox + +import ( + "strconv" + "time" + + "github.com/alist-org/alist/v3/internal/model" +) + +type File struct { + //TkbindId int `json:"tkbind_id"` + //OwnerType int `json:"owner_type"` + //Category int `json:"category"` + //RealCategory string `json:"real_category"` + FsId int64 `json:"fs_id"` + ServerMtime int64 `json:"server_mtime"` + //OperId int `json:"oper_id"` + //ServerCtime int `json:"server_ctime"` + Thumbs struct { + //Icon string `json:"icon"` + Url3 string `json:"url3"` + //Url2 string `json:"url2"` + //Url1 string `json:"url1"` + } `json:"thumbs"` + //Wpfile int `json:"wpfile"` + //LocalMtime int `json:"local_mtime"` + Size int64 `json:"size"` + //ExtentTinyint7 int `json:"extent_tinyint7"` + Path string `json:"path"` + //Share int `json:"share"` + //ServerAtime int `json:"server_atime"` + //Pl int `json:"pl"` + //LocalCtime int `json:"local_ctime"` + ServerFilename string `json:"server_filename"` + //Md5 string `json:"md5"` + //OwnerId int `json:"owner_id"` + //Unlist int `json:"unlist"` + Isdir int `json:"isdir"` +} + +type ListResp struct { + Errno int `json:"errno"` + GuidInfo string `json:"guid_info"` + List []File `json:"list"` + //RequestId int64 `json:"request_id"` 接口返回有时是int有时是string + Guid int `json:"guid"` +} + +func fileToObj(f File) *model.ObjThumb { + return &model.ObjThumb{ + Object: model.Object{ + ID: strconv.FormatInt(f.FsId, 10), 
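// Editor's note: server_mtime is epoch seconds, hence the
// time.Unix(f.ServerMtime, 0) conversion below, and isdir is an integer
// flag, hence the f.Isdir == 1 comparison; e.g. time.Unix(1700000000, 0)
// is 2023-11-14T22:13:20Z.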
+ Name: f.ServerFilename, + Size: f.Size, + Modified: time.Unix(f.ServerMtime, 0), + IsFolder: f.Isdir == 1, + }, + Thumbnail: model.Thumbnail{Thumbnail: f.Thumbs.Url3}, + } +} + +type DownloadResp struct { + Errno int `json:"errno"` + Dlink []struct { + Dlink string `json:"dlink"` + } `json:"dlink"` +} + +type DownloadResp2 struct { + Errno int `json:"errno"` + Info []struct { + Dlink string `json:"dlink"` + } `json:"info"` + //RequestID int64 `json:"request_id"` +} + +type HomeInfoResp struct { + Errno int `json:"errno"` + Data struct { + Sign1 string `json:"sign1"` + Sign3 string `json:"sign3"` + Timestamp int `json:"timestamp"` + } `json:"data"` +} + +type PrecreateResp struct { + Path string `json:"path"` + Uploadid string `json:"uploadid"` + ReturnType int `json:"return_type"` + BlockList []int `json:"block_list"` + Errno int `json:"errno"` + //RequestId int64 `json:"request_id"` +} + +type CheckLoginResp struct { + Errno int `json:"errno"` +} + +type LocateUploadResp struct { + Host string `json:"host"` +} + +type CreateResp struct { + Errno int `json:"errno"` +} diff --git a/drivers/terabox/util.go b/drivers/terabox/util.go new file mode 100644 index 0000000000000000000000000000000000000000..058eecd6085bf20024ea153b5fd7a1a7f3f490ba --- /dev/null +++ b/drivers/terabox/util.go @@ -0,0 +1,285 @@ +package terabox + +import ( + "encoding/base64" + "fmt" + "net/http" + "net/url" + "regexp" + "strconv" + "strings" + "time" + + "github.com/alist-org/alist/v3/drivers/base" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/go-resty/resty/v2" + log "github.com/sirupsen/logrus" +) + +const ( + initialChunkSize int64 = 4 << 20 // 4MB + initialSizeThreshold int64 = 4 << 30 // 4GB +) + +func getStrBetween(raw, start, end string) string { + regexPattern := fmt.Sprintf(`%s(.*?)%s`, regexp.QuoteMeta(start), regexp.QuoteMeta(end)) + regex := regexp.MustCompile(regexPattern) + matches := regex.FindStringSubmatch(raw) + if len(matches) < 2 { + return "" + } + mid := matches[1] + return mid +} + +func (d *Terabox) resetJsToken() error { + u := d.base_url + res, err := base.RestyClient.R().SetHeaders(map[string]string{ + "Cookie": d.Cookie, + "Accept": "application/json, text/plain, */*", + "Referer": d.base_url, + "User-Agent": base.UserAgent, + "X-Requested-With": "XMLHttpRequest", + }).Get(u) + if err != nil { + return err + } + html := res.String() + jsToken := getStrBetween(html, "`function%20fn%28a%29%7Bwindow.jsToken%20%3D%20a%7D%3Bfn%28%22", "%22%29`") + if jsToken == "" { + return fmt.Errorf("jsToken not found, html: %s", html) + } + d.JsToken = jsToken + return nil +} + +func (d *Terabox) request(rurl string, method string, callback base.ReqCallback, resp interface{}, noRetry ...bool) ([]byte, error) { + req := base.RestyClient.R() + req.SetHeaders(map[string]string{ + "Cookie": d.Cookie, + "Accept": "application/json, text/plain, */*", + "Referer": d.base_url, + "User-Agent": base.UserAgent, + "X-Requested-With": "XMLHttpRequest", + }) + req.SetQueryParams(map[string]string{ + "app_id": "250528", + "web": "1", + "channel": "dubox", + "clienttype": "0", + "jsToken": d.JsToken, + }) + if callback != nil { + callback(req) + } + if resp != nil { + req.SetResult(resp) + } + res, err := req.Execute(method, d.base_url+rurl) + if err != nil { + return nil, err + } + errno := utils.Json.Get(res.Body(), "errno").ToInt() + if errno == 4000023 { + // reget jsToken + err = d.resetJsToken() + if err != nil { + return nil, err + } + if 
!utils.IsBool(noRetry...) { + return d.request(rurl, method, callback, resp, true) + } + } else if errno == -6 { + header := res.Header() + log.Debugln(header) + urlDomainPrefix := header.Get("Url-Domain-Prefix") + if len(urlDomainPrefix) > 0 { + d.url_domain_prefix = urlDomainPrefix + d.base_url = "https://" + d.url_domain_prefix + ".terabox.com" + log.Debugln("Redirect base_url to", d.base_url) + return d.request(rurl, method, callback, resp, noRetry...) + } + } + return res.Body(), nil +} + +func (d *Terabox) get(pathname string, params map[string]string, resp interface{}) ([]byte, error) { + return d.request(pathname, http.MethodGet, func(req *resty.Request) { + if params != nil { + req.SetQueryParams(params) + } + }, resp) +} + +func (d *Terabox) post(pathname string, params map[string]string, data interface{}, resp interface{}) ([]byte, error) { + return d.request(pathname, http.MethodPost, func(req *resty.Request) { + if params != nil { + req.SetQueryParams(params) + } + req.SetBody(data) + }, resp) +} + +func (d *Terabox) post_form(pathname string, params map[string]string, data map[string]string, resp interface{}) ([]byte, error) { + return d.request(pathname, http.MethodPost, func(req *resty.Request) { + if params != nil { + req.SetQueryParams(params) + } + req.SetFormData(data) + }, resp) +} + +func (d *Terabox) getFiles(dir string) ([]File, error) { + page := 1 + num := 100 + params := map[string]string{ + "dir": dir, + } + if d.OrderBy != "" { + params["order"] = d.OrderBy + if d.OrderDirection == "desc" { + params["desc"] = "1" + } + } + res := make([]File, 0) + for { + params["page"] = strconv.Itoa(page) + params["num"] = strconv.Itoa(num) + var resp ListResp + _, err := d.get("/api/list", params, &resp) + if err != nil { + return nil, err + } + if resp.Errno == 9000 { + return nil, fmt.Errorf("terabox is not yet available in this area") + } + if len(resp.List) == 0 { + break + } + res = append(res, resp.List...) 
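// Editor's sketch of the pagination contract this loop relies on (listAll
// and fetch are hypothetical names, not part of the driver): pages are
// requested 1-based with a fixed page size, and the first empty page ends
// the listing.
//
//	func listAll(fetch func(page, num int) ([]File, error)) ([]File, error) {
//		var all []File
//		for page := 1; ; page++ {
//			batch, err := fetch(page, 100)
//			if err != nil {
//				return nil, err
//			}
//			if len(batch) == 0 {
//				break // first empty page terminates
//			}
//			all = append(all, batch...)
//		}
//		return all, nil
//	}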
+ page++ + } + return res, nil +} + +func sign(s1, s2 string) string { + var a = make([]int, 256) + var p = make([]int, 256) + var o []byte + var v = len(s1) + for q := 0; q < 256; q++ { + a[q] = int(s1[(q % v) : (q%v)+1][0]) + p[q] = q + } + for u, q := 0, 0; q < 256; q++ { + u = (u + p[q] + a[q]) % 256 + p[q], p[u] = p[u], p[q] + } + for i, u, q := 0, 0, 0; q < len(s2); q++ { + i = (i + 1) % 256 + u = (u + p[i]) % 256 + p[i], p[u] = p[u], p[i] + k := p[((p[i] + p[u]) % 256)] + o = append(o, byte(int(s2[q])^k)) + } + return base64.StdEncoding.EncodeToString(o) +} + +func (d *Terabox) genSign() (string, error) { + var resp HomeInfoResp + _, err := d.get("/api/home/info", map[string]string{}, &resp) + if err != nil { + return "", err + } + return sign(resp.Data.Sign3, resp.Data.Sign1), nil +} + +func (d *Terabox) linkOfficial(file model.Obj, args model.LinkArgs) (*model.Link, error) { + var resp DownloadResp + signString, err := d.genSign() + if err != nil { + return nil, err + } + params := map[string]string{ + "type": "dlink", + "fidlist": fmt.Sprintf("[%s]", file.GetID()), + "sign": signString, + "vip": "2", + "timestamp": strconv.FormatInt(time.Now().Unix(), 10), + } + _, err = d.get("/api/download", params, &resp) + if err != nil { + return nil, err + } + + if len(resp.Dlink) == 0 { + return nil, fmt.Errorf("fid %s no dlink found, errno: %d", file.GetID(), resp.Errno) + } + + res, err := base.NoRedirectClient.R().SetHeader("Cookie", d.Cookie).SetHeader("User-Agent", base.UserAgent).Get(resp.Dlink[0].Dlink) + if err != nil { + return nil, err + } + u := res.Header().Get("location") + return &model.Link{ + URL: u, + Header: http.Header{ + "User-Agent": []string{base.UserAgent}, + }, + }, nil +} + +func (d *Terabox) linkCrack(file model.Obj, args model.LinkArgs) (*model.Link, error) { + var resp DownloadResp2 + param := map[string]string{ + "target": fmt.Sprintf("[\"%s\"]", file.GetPath()), + "dlink": "1", + "origin": "dlna", + } + _, err := d.get("/api/filemetas", param, &resp) + if err != nil { + return nil, err + } + return &model.Link{ + URL: resp.Info[0].Dlink, + Header: http.Header{ + "User-Agent": []string{base.UserAgent}, + }, + }, nil +} + +func (d *Terabox) manage(opera string, filelist interface{}) ([]byte, error) { + params := map[string]string{ + "onnest": "fail", + "opera": opera, + } + marshal, err := utils.Json.Marshal(filelist) + if err != nil { + return nil, err + } + data := fmt.Sprintf("async=0&filelist=%s&ondup=newcopy", encodeURIComponent(string(marshal))) + return d.post("/api/filemanager", params, data, nil) +} + +func encodeURIComponent(str string) string { + r := url.QueryEscape(str) + r = strings.ReplaceAll(r, "+", "%20") + return r +} + +func calculateChunkSize(streamSize int64) int64 { + chunkSize := initialChunkSize + sizeThreshold := initialSizeThreshold + + if streamSize < chunkSize { + return streamSize + } + + for streamSize > sizeThreshold { + chunkSize <<= 1 + sizeThreshold <<= 1 + } + + return chunkSize +} diff --git a/drivers/thunder/driver.go b/drivers/thunder/driver.go new file mode 100644 index 0000000000000000000000000000000000000000..d6502e9b03373ec4a7f7ebbf9ac12417888a4ab8 --- /dev/null +++ b/drivers/thunder/driver.go @@ -0,0 +1,553 @@ +package thunder + +import ( + "context" + "fmt" + "net/http" + "strings" + + "github.com/alist-org/alist/v3/drivers/base" + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/errs" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/internal/op" 
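// Editor's note: hash_extend below supplies the GCID hash that Xunlei uses
// to deduplicate uploads (see Put), and the aws-sdk-go packages back the
// UPLOAD_TYPE_RESUMABLE branch, which streams the file to an S3-compatible
// endpoint through s3manager.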
+ "github.com/alist-org/alist/v3/pkg/utils" + hash_extend "github.com/alist-org/alist/v3/pkg/utils/hash" + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/s3/s3manager" + "github.com/go-resty/resty/v2" +) + +type Thunder struct { + *XunLeiCommon + model.Storage + Addition + + identity string +} + +func (x *Thunder) Config() driver.Config { + return config +} + +func (x *Thunder) GetAddition() driver.Additional { + return &x.Addition +} + +func (x *Thunder) Init(ctx context.Context) (err error) { + // 初始化所需参数 + if x.XunLeiCommon == nil { + x.XunLeiCommon = &XunLeiCommon{ + Common: &Common{ + client: base.NewRestyClient(), + Algorithms: []string{ + "HPxr4BVygTQVtQkIMwQH33ywbgYG5l4JoR", + "GzhNkZ8pOBsCY+7", + "v+l0ImTpG7c7/", + "e5ztohgVXNP", + "t", + "EbXUWyVVqQbQX39Mbjn2geok3/0WEkAVxeqhtx857++kjJiRheP8l77gO", + "o7dvYgbRMOpHXxCs", + "6MW8TD8DphmakaxCqVrfv7NReRRN7ck3KLnXBculD58MvxjFRqT+", + "kmo0HxCKVfmxoZswLB4bVA/dwqbVAYghSb", + "j", + "4scKJNdd7F27Hv7tbt", + }, + DeviceID: utils.GetMD5EncodeStr(x.Username + x.Password), + ClientID: "Xp6vsxz_7IYVw2BB", + ClientSecret: "Xp6vsy4tN9toTVdMSpomVdXpRmES", + ClientVersion: "7.51.0.8196", + PackageName: "com.xunlei.downloadprovider", + UserAgent: "ANDROID-com.xunlei.downloadprovider/7.51.0.8196 netWorkType/5G appid/40 deviceName/Xiaomi_M2004j7ac deviceModel/M2004J7AC OSVersion/12 protocolVersion/301 platformVersion/10 sdkVersion/220200 Oauth2Client/0.9 (Linux 4_14_186-perf-gddfs8vbb238b) (JAVA 0)", + DownloadUserAgent: "Dalvik/2.1.0 (Linux; U; Android 12; M2004J7AC Build/SP1A.210812.016)", + + refreshCTokenCk: func(token string) { + x.CaptchaToken = token + op.MustSaveDriverStorage(x) + }, + }, + refreshTokenFunc: func() error { + // 通过RefreshToken刷新 + token, err := x.RefreshToken(x.TokenResp.RefreshToken) + if err != nil { + // 重新登录 + token, err = x.Login(x.Username, x.Password) + if err != nil { + x.GetStorage().SetStatus(fmt.Sprintf("%+v", err.Error())) + op.MustSaveDriverStorage(x) + } + } + x.SetTokenResp(token) + return err + }, + } + } + + // 自定义验证码token + ctoekn := strings.TrimSpace(x.CaptchaToken) + if ctoekn != "" { + x.SetCaptchaToken(ctoekn) + } + + // 防止重复登录 + identity := x.GetIdentity() + if x.identity != identity || !x.IsLogin() { + x.identity = identity + // 登录 + token, err := x.Login(x.Username, x.Password) + if err != nil { + return err + } + x.SetTokenResp(token) + } + return nil +} + +func (x *Thunder) Drop(ctx context.Context) error { + return nil +} + +type ThunderExpert struct { + *XunLeiCommon + model.Storage + ExpertAddition + + identity string +} + +func (x *ThunderExpert) Config() driver.Config { + return configExpert +} + +func (x *ThunderExpert) GetAddition() driver.Additional { + return &x.ExpertAddition +} + +func (x *ThunderExpert) Init(ctx context.Context) (err error) { + // 防止重复登录 + identity := x.GetIdentity() + if identity != x.identity || !x.IsLogin() { + x.identity = identity + x.XunLeiCommon = &XunLeiCommon{ + Common: &Common{ + client: base.NewRestyClient(), + + DeviceID: func() string { + if len(x.DeviceID) != 32 { + return utils.GetMD5EncodeStr(x.DeviceID) + } + return x.DeviceID + }(), + ClientID: x.ClientID, + ClientSecret: x.ClientSecret, + ClientVersion: x.ClientVersion, + PackageName: x.PackageName, + UserAgent: x.UserAgent, + DownloadUserAgent: x.DownloadUserAgent, + UseVideoUrl: x.UseVideoUrl, + + refreshCTokenCk: func(token string) { + x.CaptchaToken = token + op.MustSaveDriverStorage(x) + }, 
+ }, + } + + if x.CaptchaToken != "" { + x.SetCaptchaToken(x.CaptchaToken) + } + + // 签名方法 + if x.SignType == "captcha_sign" { + x.Common.Timestamp = x.Timestamp + x.Common.CaptchaSign = x.CaptchaSign + } else { + x.Common.Algorithms = strings.Split(x.Algorithms, ",") + } + + // 登录方式 + if x.LoginType == "refresh_token" { + // 通过RefreshToken登录 + token, err := x.XunLeiCommon.RefreshToken(x.ExpertAddition.RefreshToken) + if err != nil { + return err + } + x.SetTokenResp(token) + + // 刷新token方法 + x.SetRefreshTokenFunc(func() error { + token, err := x.XunLeiCommon.RefreshToken(x.TokenResp.RefreshToken) + if err != nil { + x.GetStorage().SetStatus(fmt.Sprintf("%+v", err.Error())) + } + x.SetTokenResp(token) + op.MustSaveDriverStorage(x) + return err + }) + } else { + // 通过用户密码登录 + token, err := x.Login(x.Username, x.Password) + if err != nil { + return err + } + x.SetTokenResp(token) + x.SetRefreshTokenFunc(func() error { + token, err := x.XunLeiCommon.RefreshToken(x.TokenResp.RefreshToken) + if err != nil { + token, err = x.Login(x.Username, x.Password) + if err != nil { + x.GetStorage().SetStatus(fmt.Sprintf("%+v", err.Error())) + } + } + x.SetTokenResp(token) + op.MustSaveDriverStorage(x) + return err + }) + } + } else { + // 仅修改验证码token + if x.CaptchaToken != "" { + x.SetCaptchaToken(x.CaptchaToken) + } + x.XunLeiCommon.UserAgent = x.UserAgent + x.XunLeiCommon.DownloadUserAgent = x.DownloadUserAgent + x.XunLeiCommon.UseVideoUrl = x.UseVideoUrl + } + return nil +} + +func (x *ThunderExpert) Drop(ctx context.Context) error { + return nil +} + +func (x *ThunderExpert) SetTokenResp(token *TokenResp) { + x.XunLeiCommon.SetTokenResp(token) + if token != nil { + x.ExpertAddition.RefreshToken = token.RefreshToken + } +} + +type XunLeiCommon struct { + *Common + *TokenResp // 登录信息 + + refreshTokenFunc func() error +} + +func (xc *XunLeiCommon) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) { + return xc.getFiles(ctx, dir.GetID()) +} + +func (xc *XunLeiCommon) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) { + var lFile Files + _, err := xc.Request(FILE_API_URL+"/{fileID}", http.MethodGet, func(r *resty.Request) { + r.SetContext(ctx) + r.SetPathParam("fileID", file.GetID()) + //r.SetQueryParam("space", "") + }, &lFile) + if err != nil { + return nil, err + } + link := &model.Link{ + URL: lFile.WebContentLink, + Header: http.Header{ + "User-Agent": {xc.DownloadUserAgent}, + }, + } + + if xc.UseVideoUrl { + for _, media := range lFile.Medias { + if media.Link.URL != "" { + link.URL = media.Link.URL + break + } + } + } + + /* + strs := regexp.MustCompile(`e=([0-9]*)`).FindStringSubmatch(lFile.WebContentLink) + if len(strs) == 2 { + timestamp, err := strconv.ParseInt(strs[1], 10, 64) + if err == nil { + expired := time.Duration(timestamp-time.Now().Unix()) * time.Second + link.Expiration = &expired + } + } + */ + return link, nil +} + +func (xc *XunLeiCommon) MakeDir(ctx context.Context, parentDir model.Obj, dirName string) error { + _, err := xc.Request(FILE_API_URL, http.MethodPost, func(r *resty.Request) { + r.SetContext(ctx) + r.SetBody(&base.Json{ + "kind": FOLDER, + "name": dirName, + "parent_id": parentDir.GetID(), + }) + }, nil) + return err +} + +func (xc *XunLeiCommon) Move(ctx context.Context, srcObj, dstDir model.Obj) error { + _, err := xc.Request(FILE_API_URL+":batchMove", http.MethodPost, func(r *resty.Request) { + r.SetContext(ctx) + r.SetBody(&base.Json{ + "to": base.Json{"parent_id": dstDir.GetID()}, + "ids": 
[]string{srcObj.GetID()}, + }) + }, nil) + return err +} + +func (xc *XunLeiCommon) Rename(ctx context.Context, srcObj model.Obj, newName string) error { + _, err := xc.Request(FILE_API_URL+"/{fileID}", http.MethodPatch, func(r *resty.Request) { + r.SetContext(ctx) + r.SetPathParam("fileID", srcObj.GetID()) + r.SetBody(&base.Json{"name": newName}) + }, nil) + return err +} + +func (xc *XunLeiCommon) Offline(ctx context.Context, args model.OtherArgs) (interface{}, error) { + _, err := xc.Request(FILE_API_URL, http.MethodPost, func(r *resty.Request) { + r.SetContext(ctx) + r.SetHeaders(map[string]string{ + "X-Device-Id": xc.DeviceID, + "User-Agent": xc.UserAgent, + "Peer-Id": xc.DeviceID, + "client_id": xc.ClientID, + "x-client-id": xc.ClientID, + "X-Guid": xc.DeviceID, + }) + r.SetBody(&base.Json{ + "kind": "drive#file", + "name": "", + "parent_id": args.Obj.GetID(), + "upload_type": "UPLOAD_TYPE_URL", + "url": &base.Json{ + "url": args.Data, + "params": "{}", + "parent_id": args.Obj.GetID(), + }, + }) + }, nil) + if err != nil { + return nil, err + } + return "ok", nil +} + +func (xc *XunLeiCommon) Copy(ctx context.Context, srcObj, dstDir model.Obj) error { + _, err := xc.Request(FILE_API_URL+":batchCopy", http.MethodPost, func(r *resty.Request) { + r.SetContext(ctx) + r.SetBody(&base.Json{ + "to": base.Json{"parent_id": dstDir.GetID()}, + "ids": []string{srcObj.GetID()}, + }) + }, nil) + return err +} + +func (xc *XunLeiCommon) Remove(ctx context.Context, obj model.Obj) error { + _, err := xc.Request(FILE_API_URL+"/{fileID}/trash", http.MethodPatch, func(r *resty.Request) { + r.SetContext(ctx) + r.SetPathParam("fileID", obj.GetID()) + r.SetBody("{}") + }, nil) + return err +} + +func (xc *XunLeiCommon) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error { + hi := stream.GetHash() + gcid := hi.GetHash(hash_extend.GCID) + if len(gcid) < hash_extend.GCID.Width { + tFile, err := stream.CacheFullInTempFile() + if err != nil { + return err + } + + gcid, err = utils.HashFile(hash_extend.GCID, tFile, stream.GetSize()) + if err != nil { + return err + } + } + + var resp UploadTaskResponse + _, err := xc.Request(FILE_API_URL, http.MethodPost, func(r *resty.Request) { + r.SetContext(ctx) + r.SetBody(&base.Json{ + "kind": FILE, + "parent_id": dstDir.GetID(), + "name": stream.GetName(), + "size": stream.GetSize(), + "hash": gcid, + "upload_type": UPLOAD_TYPE_RESUMABLE, + }) + }, &resp) + if err != nil { + return err + } + + param := resp.Resumable.Params + if resp.UploadType == UPLOAD_TYPE_RESUMABLE { + param.Endpoint = strings.TrimLeft(param.Endpoint, param.Bucket+".") + s, err := session.NewSession(&aws.Config{ + Credentials: credentials.NewStaticCredentials(param.AccessKeyID, param.AccessKeySecret, param.SecurityToken), + Region: aws.String("xunlei"), + Endpoint: aws.String(param.Endpoint), + }) + if err != nil { + return err + } + uploader := s3manager.NewUploader(s) + if stream.GetSize() > s3manager.MaxUploadParts*s3manager.DefaultUploadPartSize { + uploader.PartSize = stream.GetSize() / (s3manager.MaxUploadParts - 1) + } + _, err = uploader.UploadWithContext(ctx, &s3manager.UploadInput{ + Bucket: aws.String(param.Bucket), + Key: aws.String(param.Key), + Expires: aws.Time(param.Expiration), + Body: stream, + }) + return err + } + return nil +} + +func (xc *XunLeiCommon) getFiles(ctx context.Context, folderId string) ([]model.Obj, error) { + files := make([]model.Obj, 0) + var pageToken string + for { + var fileList FileList + _, err := 
xc.Request(FILE_API_URL, http.MethodGet, func(r *resty.Request) { + r.SetContext(ctx) + r.SetQueryParams(map[string]string{ + "space": "", + "__type": "drive", + "refresh": "true", + "__sync": "true", + "parent_id": folderId, + "page_token": pageToken, + "with_audit": "true", + "limit": "100", + "filters": `{"phase":{"eq":"PHASE_TYPE_COMPLETE"},"trashed":{"eq":false}}`, + }) + }, &fileList) + if err != nil { + return nil, err + } + + for i := 0; i < len(fileList.Files); i++ { + files = append(files, &fileList.Files[i]) + } + + if fileList.NextPageToken == "" { + break + } + pageToken = fileList.NextPageToken + } + return files, nil +} + +// 设置刷新Token的方法 +func (xc *XunLeiCommon) SetRefreshTokenFunc(fn func() error) { + xc.refreshTokenFunc = fn +} + +// 设置Token +func (xc *XunLeiCommon) SetTokenResp(tr *TokenResp) { + xc.TokenResp = tr +} + +// 携带Authorization和CaptchaToken的请求 +func (xc *XunLeiCommon) Request(url string, method string, callback base.ReqCallback, resp interface{}) ([]byte, error) { + data, err := xc.Common.Request(url, method, func(req *resty.Request) { + req.SetHeaders(map[string]string{ + "Authorization": xc.Token(), + "X-Captcha-Token": xc.GetCaptchaToken(), + }) + if callback != nil { + callback(req) + } + }, resp) + + errResp, ok := err.(*ErrResp) + if !ok { + return nil, err + } + + switch errResp.ErrorCode { + case 0: + return data, nil + case 4122, 4121, 10, 16: + if xc.refreshTokenFunc != nil { + if err = xc.refreshTokenFunc(); err == nil { + break + } + } + return nil, err + case 9: // 验证码token过期 + if err = xc.RefreshCaptchaTokenAtLogin(GetAction(method, url), xc.UserID); err != nil { + return nil, err + } + default: + return nil, err + } + return xc.Request(url, method, callback, resp) +} + +// 刷新Token +func (xc *XunLeiCommon) RefreshToken(refreshToken string) (*TokenResp, error) { + var resp TokenResp + _, err := xc.Common.Request(XLUSER_API_URL+"/auth/token", http.MethodPost, func(req *resty.Request) { + req.SetBody(&base.Json{ + "grant_type": "refresh_token", + "refresh_token": refreshToken, + "client_id": xc.ClientID, + "client_secret": xc.ClientSecret, + }) + }, &resp) + if err != nil { + return nil, err + } + + if resp.RefreshToken == "" { + return nil, errs.EmptyToken + } + return &resp, nil +} + +// 登录 +func (xc *XunLeiCommon) Login(username, password string) (*TokenResp, error) { + url := XLUSER_API_URL + "/auth/signin" + err := xc.RefreshCaptchaTokenInLogin(GetAction(http.MethodPost, url), username) + if err != nil { + return nil, err + } + + var resp TokenResp + _, err = xc.Common.Request(url, http.MethodPost, func(req *resty.Request) { + req.SetBody(&SignInRequest{ + CaptchaToken: xc.GetCaptchaToken(), + ClientID: xc.ClientID, + ClientSecret: xc.ClientSecret, + Username: username, + Password: password, + }) + }, &resp) + if err != nil { + return nil, err + } + return &resp, nil +} + +func (xc *XunLeiCommon) IsLogin() bool { + if xc.TokenResp == nil { + return false + } + _, err := xc.Request(XLUSER_API_URL+"/user/me", http.MethodGet, nil, nil) + return err == nil +} diff --git a/drivers/thunder/meta.go b/drivers/thunder/meta.go new file mode 100644 index 0000000000000000000000000000000000000000..12b01cbaa16cb7487533c6ac9a9d95fbdacaf03c --- /dev/null +++ b/drivers/thunder/meta.go @@ -0,0 +1,102 @@ +package thunder + +import ( + "crypto/md5" + "encoding/hex" + + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/op" + "github.com/alist-org/alist/v3/pkg/utils" +) + +// 高级设置 +type ExpertAddition struct { + driver.RootID 
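// Editor's note: the field tags below follow the driver-addition form
// convention shown in drivers/template/meta.go, e.g.
//
//	Field string `json:"field" type:"select" required:"true" options:"a,b,c" default:"a"`
//
// where type selects the input widget, options lists the select choices,
// default pre-fills the value, and help renders a hint string.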
+
+	LoginType string `json:"login_type" type:"select" options:"user,refresh_token" default:"user"`
+	SignType  string `json:"sign_type" type:"select" options:"algorithms,captcha_sign" default:"algorithms"`
+
+	// login method 1
+	Username string `json:"username" required:"true" help:"login type is user,this is required"`
+	Password string `json:"password" required:"true" help:"login type is user,this is required"`
+	// login method 2
+	RefreshToken string `json:"refresh_token" required:"true" help:"login type is refresh_token,this is required"`
+
+	// signing method 1
+	Algorithms string `json:"algorithms" required:"true" help:"sign type is algorithms,this is required" default:"HPxr4BVygTQVtQkIMwQH33ywbgYG5l4JoR,GzhNkZ8pOBsCY+7,v+l0ImTpG7c7/,e5ztohgVXNP,t,EbXUWyVVqQbQX39Mbjn2geok3/0WEkAVxeqhtx857++kjJiRheP8l77gO,o7dvYgbRMOpHXxCs,6MW8TD8DphmakaxCqVrfv7NReRRN7ck3KLnXBculD58MvxjFRqT+,kmo0HxCKVfmxoZswLB4bVA/dwqbVAYghSb,j,4scKJNdd7F27Hv7tbt"`
+	// signing method 2
+	CaptchaSign string `json:"captcha_sign" required:"true" help:"sign type is captcha_sign,this is required"`
+	Timestamp   string `json:"timestamp" required:"true" help:"sign type is captcha_sign,this is required"`
+
+	// captcha
+	CaptchaToken string `json:"captcha_token"`
+
+	// required and affects login; determined by the signing scheme
+	DeviceID      string `json:"device_id" required:"true" default:"9aa5c268e7bcfc197a9ad88e2fb330e5"`
+	ClientID      string `json:"client_id" required:"true" default:"Xp6vsxz_7IYVw2BB"`
+	ClientSecret  string `json:"client_secret" required:"true" default:"Xp6vsy4tN9toTVdMSpomVdXpRmES"`
+	ClientVersion string `json:"client_version" required:"true" default:"7.51.0.8196"`
+	PackageName   string `json:"package_name" required:"true" default:"com.xunlei.downloadprovider"`
+
+	// does not affect login, but affects download speed
+	UserAgent         string `json:"user_agent" required:"true" default:"ANDROID-com.xunlei.downloadprovider/7.51.0.8196 netWorkType/4G appid/40 deviceName/Xiaomi_M2004j7ac deviceModel/M2004J7AC OSVersion/12 protocolVersion/301 platformVersion/10 sdkVersion/220200 Oauth2Client/0.9 (Linux 4_14_186-perf-gdcf98eab238b) (JAVA 0)"`
+	DownloadUserAgent string `json:"download_user_agent" required:"true" default:"Dalvik/2.1.0 (Linux; U; Android 12; M2004J7AC Build/SP1A.210812.016)"`
+
+	// prefer the video URL over the download URL
+	UseVideoUrl bool `json:"use_video_url"`
+}
+
+// login fingerprint, used to decide whether a fresh login is needed
+func (i *ExpertAddition) GetIdentity() string {
+	hash := md5.New()
+	if i.LoginType == "refresh_token" {
+		hash.Write([]byte(i.RefreshToken))
+	} else {
+		hash.Write([]byte(i.Username + i.Password))
+	}
+
+	if i.SignType == "captcha_sign" {
+		hash.Write([]byte(i.CaptchaSign + i.Timestamp))
+	} else {
+		hash.Write([]byte(i.Algorithms))
+	}
+
+	hash.Write([]byte(i.DeviceID))
+	hash.Write([]byte(i.ClientID))
+	hash.Write([]byte(i.ClientSecret))
+	hash.Write([]byte(i.ClientVersion))
+	hash.Write([]byte(i.PackageName))
+	return hex.EncodeToString(hash.Sum(nil))
+}
+
+type Addition struct {
+	driver.RootID
+	Username     string `json:"username" required:"true"`
+	Password     string `json:"password" required:"true"`
+	CaptchaToken string `json:"captcha_token"`
+}
+
+// login fingerprint, used to decide whether a fresh login is needed
+func (i *Addition) GetIdentity() string {
+	return utils.GetMD5EncodeStr(i.Username + i.Password)
+}
+
+var config = driver.Config{
+	Name:      "Thunder",
+	LocalSort: true,
+	OnlyProxy: true,
+}
+
+var configExpert = driver.Config{
+	Name:      "ThunderExpert",
+	LocalSort: true,
+}
+
+func init() {
+	op.RegisterDriver(func() driver.Driver {
+		return &Thunder{}
+	})
+	op.RegisterDriver(func() driver.Driver {
+		return &ThunderExpert{}
+	})
+}
diff --git a/drivers/thunder/types.go b/drivers/thunder/types.go
new file
mode 100644 index 0000000000000000000000000000000000000000..7c223673448d46de8902d40d9390df9987e7ea87 --- /dev/null +++ b/drivers/thunder/types.go @@ -0,0 +1,206 @@ +package thunder + +import ( + "fmt" + "strconv" + "time" + + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/pkg/utils" + hash_extend "github.com/alist-org/alist/v3/pkg/utils/hash" +) + +type ErrResp struct { + ErrorCode int64 `json:"error_code"` + ErrorMsg string `json:"error"` + ErrorDescription string `json:"error_description"` + // ErrorDetails interface{} `json:"error_details"` +} + +func (e *ErrResp) IsError() bool { + return e.ErrorCode != 0 || e.ErrorMsg != "" || e.ErrorDescription != "" +} + +func (e *ErrResp) Error() string { + return fmt.Sprintf("ErrorCode: %d ,Error: %s ,ErrorDescription: %s ", e.ErrorCode, e.ErrorMsg, e.ErrorDescription) +} + +/* +* 验证码Token +**/ +type CaptchaTokenRequest struct { + Action string `json:"action"` + CaptchaToken string `json:"captcha_token"` + ClientID string `json:"client_id"` + DeviceID string `json:"device_id"` + Meta map[string]string `json:"meta"` + RedirectUri string `json:"redirect_uri"` +} + +type CaptchaTokenResponse struct { + CaptchaToken string `json:"captcha_token"` + ExpiresIn int64 `json:"expires_in"` + Url string `json:"url"` +} + +/* +* 登录 +**/ +type TokenResp struct { + TokenType string `json:"token_type"` + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + ExpiresIn int64 `json:"expires_in"` + + Sub string `json:"sub"` + UserID string `json:"user_id"` +} + +func (t *TokenResp) Token() string { + return fmt.Sprint(t.TokenType, " ", t.AccessToken) +} + +type SignInRequest struct { + CaptchaToken string `json:"captcha_token"` + + ClientID string `json:"client_id"` + ClientSecret string `json:"client_secret"` + + Username string `json:"username"` + Password string `json:"password"` +} + +/* +* 文件 +**/ +type FileList struct { + Kind string `json:"kind"` + NextPageToken string `json:"next_page_token"` + Files []Files `json:"files"` + Version string `json:"version"` + VersionOutdated bool `json:"version_outdated"` +} + +type Link struct { + URL string `json:"url"` + Token string `json:"token"` + Expire time.Time `json:"expire"` + Type string `json:"type"` +} + +var _ model.Obj = (*Files)(nil) + +type Files struct { + Kind string `json:"kind"` + ID string `json:"id"` + ParentID string `json:"parent_id"` + Name string `json:"name"` + //UserID string `json:"user_id"` + Size string `json:"size"` + //Revision string `json:"revision"` + //FileExtension string `json:"file_extension"` + //MimeType string `json:"mime_type"` + //Starred bool `json:"starred"` + WebContentLink string `json:"web_content_link"` + CreatedTime time.Time `json:"created_time"` + ModifiedTime time.Time `json:"modified_time"` + IconLink string `json:"icon_link"` + ThumbnailLink string `json:"thumbnail_link"` + // Md5Checksum string `json:"md5_checksum"` + Hash string `json:"hash"` + // Links map[string]Link `json:"links"` + // Phase string `json:"phase"` + // Audit struct { + // Status string `json:"status"` + // Message string `json:"message"` + // Title string `json:"title"` + // } `json:"audit"` + Medias []struct { + //Category string `json:"category"` + //IconLink string `json:"icon_link"` + //IsDefault bool `json:"is_default"` + //IsOrigin bool `json:"is_origin"` + //IsVisible bool `json:"is_visible"` + Link Link `json:"link"` + //MediaID string `json:"media_id"` + //MediaName string `json:"media_name"` + //NeedMoreQuota bool 
`json:"need_more_quota"` + //Priority int `json:"priority"` + //RedirectLink string `json:"redirect_link"` + //ResolutionName string `json:"resolution_name"` + // Video struct { + // AudioCodec string `json:"audio_codec"` + // BitRate int `json:"bit_rate"` + // Duration int `json:"duration"` + // FrameRate int `json:"frame_rate"` + // Height int `json:"height"` + // VideoCodec string `json:"video_codec"` + // VideoType string `json:"video_type"` + // Width int `json:"width"` + // } `json:"video"` + // VipTypes []string `json:"vip_types"` + } `json:"medias"` + Trashed bool `json:"trashed"` + DeleteTime string `json:"delete_time"` + OriginalURL string `json:"original_url"` + //Params struct{} `json:"params"` + //OriginalFileIndex int `json:"original_file_index"` + //Space string `json:"space"` + //Apps []interface{} `json:"apps"` + //Writable bool `json:"writable"` + //FolderType string `json:"folder_type"` + //Collection interface{} `json:"collection"` +} + +func (c *Files) GetHash() utils.HashInfo { + return utils.NewHashInfo(hash_extend.GCID, c.Hash) +} + +func (c *Files) GetSize() int64 { size, _ := strconv.ParseInt(c.Size, 10, 64); return size } +func (c *Files) GetName() string { return c.Name } +func (c *Files) CreateTime() time.Time { return c.CreatedTime } +func (c *Files) ModTime() time.Time { return c.ModifiedTime } +func (c *Files) IsDir() bool { return c.Kind == FOLDER } +func (c *Files) GetID() string { return c.ID } +func (c *Files) GetPath() string { return "" } +func (c *Files) Thumb() string { return c.ThumbnailLink } + +/* +* 上传 +**/ +type UploadTaskResponse struct { + UploadType string `json:"upload_type"` + + /*//UPLOAD_TYPE_FORM + Form struct { + //Headers struct{} `json:"headers"` + Kind string `json:"kind"` + Method string `json:"method"` + MultiParts struct { + OSSAccessKeyID string `json:"OSSAccessKeyId"` + Signature string `json:"Signature"` + Callback string `json:"callback"` + Key string `json:"key"` + Policy string `json:"policy"` + XUserData string `json:"x:user_data"` + } `json:"multi_parts"` + URL string `json:"url"` + } `json:"form"`*/ + + //UPLOAD_TYPE_RESUMABLE + Resumable struct { + Kind string `json:"kind"` + Params struct { + AccessKeyID string `json:"access_key_id"` + AccessKeySecret string `json:"access_key_secret"` + Bucket string `json:"bucket"` + Endpoint string `json:"endpoint"` + Expiration time.Time `json:"expiration"` + Key string `json:"key"` + SecurityToken string `json:"security_token"` + } `json:"params"` + Provider string `json:"provider"` + } `json:"resumable"` + + File Files `json:"file"` +} diff --git a/drivers/thunder/util.go b/drivers/thunder/util.go new file mode 100644 index 0000000000000000000000000000000000000000..f6dec3260cf2e6a51cf0fa812f51b314b79f468b --- /dev/null +++ b/drivers/thunder/util.go @@ -0,0 +1,202 @@ +package thunder + +import ( + "crypto/sha1" + "encoding/hex" + "fmt" + "io" + "net/http" + "regexp" + "time" + + "github.com/alist-org/alist/v3/drivers/base" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/go-resty/resty/v2" +) + +const ( + API_URL = "https://api-pan.xunlei.com/drive/v1" + FILE_API_URL = API_URL + "/files" + XLUSER_API_URL = "https://xluser-ssl.xunlei.com/v1" +) + +const ( + FOLDER = "drive#folder" + FILE = "drive#file" + RESUMABLE = "drive#resumable" +) + +const ( + UPLOAD_TYPE_UNKNOWN = "UPLOAD_TYPE_UNKNOWN" + //UPLOAD_TYPE_FORM = "UPLOAD_TYPE_FORM" + UPLOAD_TYPE_RESUMABLE = "UPLOAD_TYPE_RESUMABLE" + UPLOAD_TYPE_URL = "UPLOAD_TYPE_URL" +) + +func GetAction(method string, url string) string 
{
+	urlpath := regexp.MustCompile(`://[^/]+((/[^/\s?#]+)*)`).FindStringSubmatch(url)[1]
+	return method + ":" + urlpath
+}
+
+type Common struct {
+	client *resty.Client
+
+	captchaToken string
+
+	// signing-related: exactly one of the two alternatives is used
+	Algorithms             []string
+	Timestamp, CaptchaSign string
+
+	// required values, all of which feed into signing
+	DeviceID          string
+	ClientID          string
+	ClientSecret      string
+	ClientVersion     string
+	PackageName       string
+	UserAgent         string
+	DownloadUserAgent string
+	UseVideoUrl       bool
+
+	// callback invoked after the captcha token has been refreshed successfully
+	refreshCTokenCk func(token string)
+}
+
+func (c *Common) SetCaptchaToken(captchaToken string) {
+	c.captchaToken = captchaToken
+}
+func (c *Common) GetCaptchaToken() string {
+	return c.captchaToken
+}
+
+// refresh the captcha token (after login)
+func (c *Common) RefreshCaptchaTokenAtLogin(action, userID string) error {
+	metas := map[string]string{
+		"client_version": c.ClientVersion,
+		"package_name":   c.PackageName,
+		"user_id":        userID,
+	}
+	metas["timestamp"], metas["captcha_sign"] = c.GetCaptchaSign()
+	return c.refreshCaptchaToken(action, metas)
+}
+
+// refresh the captcha token (during login)
+func (c *Common) RefreshCaptchaTokenInLogin(action, username string) error {
+	metas := make(map[string]string)
+	if ok, _ := regexp.MatchString(`\w+([-+.]\w+)*@\w+([-.]\w+)*\.\w+([-.]\w+)*`, username); ok {
+		metas["email"] = username
+	} else if len(username) >= 11 && len(username) <= 18 {
+		metas["phone_number"] = username
+	} else {
+		metas["username"] = username
+	}
+	return c.refreshCaptchaToken(action, metas)
+}
+
+// compute the captcha signature
+func (c *Common) GetCaptchaSign() (timestamp, sign string) {
+	if len(c.Algorithms) == 0 {
+		return c.Timestamp, c.CaptchaSign
+	}
+	timestamp = fmt.Sprint(time.Now().UnixMilli())
+	str := fmt.Sprint(c.ClientID, c.ClientVersion, c.PackageName, c.DeviceID, timestamp)
+	for _, algorithm := range c.Algorithms {
+		str = utils.GetMD5EncodeStr(str + algorithm)
+	}
+	sign = "1." + str
+	return
+}
+
+// refresh the captcha token
+func (c *Common) refreshCaptchaToken(action string, metas map[string]string) error {
+	param := CaptchaTokenRequest{
+		Action:       action,
+		CaptchaToken: c.captchaToken,
+		ClientID:     c.ClientID,
+		DeviceID:     c.DeviceID,
+		Meta:         metas,
+		RedirectUri:  "xlaccsdk01://xunlei.com/callback?state=harbor",
+	}
+	var e ErrResp
+	var resp CaptchaTokenResponse
+	_, err := c.Request(XLUSER_API_URL+"/shield/captcha/init", http.MethodPost, func(req *resty.Request) {
+		req.SetError(&e).SetBody(param)
+	}, &resp)
+
+	if err != nil {
+		return err
+	}
+
+	if e.IsError() {
+		return &e
+	}
+
+	if resp.Url != "" {
+		return fmt.Errorf(`need verify: <a target="_blank" href="%s">Click Here</a>`, resp.Url)
+	}
+
+	if resp.CaptchaToken == "" {
+		return fmt.Errorf("empty captchaToken")
+	}
+
+	if c.refreshCTokenCk != nil {
+		c.refreshCTokenCk(resp.CaptchaToken)
+	}
+	c.SetCaptchaToken(resp.CaptchaToken)
+	return nil
+}
+
+// request that carries only the basic client headers
+func (c *Common) Request(url, method string, callback base.ReqCallback, resp interface{}) ([]byte, error) {
+	req := c.client.R().SetHeaders(map[string]string{
+		"user-agent":       c.UserAgent,
+		"accept":           "application/json;charset=UTF-8",
+		"x-device-id":      c.DeviceID,
+		"x-client-id":      c.ClientID,
+		"x-client-version": c.ClientVersion,
+	})
+
+	if callback != nil {
+		callback(req)
+	}
+	if resp != nil {
+		req.SetResult(resp)
+	}
+	res, err := req.Execute(method, url)
+	if err != nil {
+		return nil, err
+	}
+
+	var erron ErrResp
+	utils.Json.Unmarshal(res.Body(), &erron)
+	if erron.IsError() {
+		return nil, &erron
+	}
+
+	return res.Body(), nil
+}
+
+// compute the file's GCID
+func getGcid(r io.Reader, size int64) (string, error) {
+	calcBlockSize := func(j int64) int64 {
+		var psize int64 = 0x40000
+		for float64(j)/float64(psize) > 0x200 && psize < 0x200000 {
+			psize = psize << 1
+		}
+		return psize
+	}
+
+	hash1 := sha1.New()
+	hash2 := sha1.New()
+	readSize := calcBlockSize(size)
+	for {
+		hash2.Reset()
+		if n, err := io.CopyN(hash2, r, readSize); err != nil && n == 0 {
+			if err != io.EOF {
+				return "", err
+			}
+			break
+		}
+		hash1.Write(hash2.Sum(nil))
+	}
+	return hex.EncodeToString(hash1.Sum(nil)), nil
+}
diff --git a/drivers/thunder_browser/driver.go b/drivers/thunder_browser/driver.go
new file mode 100644
index 0000000000000000000000000000000000000000..cc5c7dad61c931a7f1ae187c707189d8d0f99a33
--- /dev/null
+++ b/drivers/thunder_browser/driver.go
@@ -0,0 +1,843 @@
+package thunder_browser
+
+import (
+	"context"
+	"fmt"
+	"net/http"
+	"regexp"
+	"strings"
+
+	"github.com/alist-org/alist/v3/drivers/base"
+	"github.com/alist-org/alist/v3/internal/driver"
+	"github.com/alist-org/alist/v3/internal/errs"
+	"github.com/alist-org/alist/v3/internal/model"
+	"github.com/alist-org/alist/v3/internal/op"
+	"github.com/alist-org/alist/v3/pkg/utils"
+	hash_extend "github.com/alist-org/alist/v3/pkg/utils/hash"
+	"github.com/aws/aws-sdk-go/aws"
+	"github.com/aws/aws-sdk-go/aws/credentials"
+	"github.com/aws/aws-sdk-go/aws/session"
+	"github.com/aws/aws-sdk-go/service/s3/s3manager"
+	"github.com/go-resty/resty/v2"
+)
+
+type ThunderBrowser struct {
+	*XunLeiBrowserCommon
+	model.Storage
+	Addition
+
+	identity string
+}
+
+func (x *ThunderBrowser) Config() driver.Config {
+	return config
+}
+
+func (x *ThunderBrowser) GetAddition() driver.Additional {
+	return &x.Addition
+}
+
+func (x *ThunderBrowser) Init(ctx context.Context) (err error) {
+
+	spaceTokenFunc := func() error {
+		// return immediately if the user has not set a "超级保险柜" (super safe) password
+		if x.SafePassword == "" {
+			return nil
+		}
+		// obtain the token via GetSafeAccessToken
+		token, err :=
x.GetSafeAccessToken(x.SafePassword) + x.SetSpaceTokenResp(token) + return err + } + + // 初始化所需参数 + if x.XunLeiBrowserCommon == nil { + x.XunLeiBrowserCommon = &XunLeiBrowserCommon{ + Common: &Common{ + client: base.NewRestyClient(), + Algorithms: []string{ + "x+I5XiTByg", + "6QU1x5DqGAV3JKg6h", + "VI1vL1WXr7st0es", + "n+/3yhlrnKs4ewhLgZhZ5ITpt554", + "UOip2PE7BLIEov/ZX6VOnsz", + "Q70h9lpViNCOC8sGVkar9o22LhBTjfP", + "IVHFuB1JcMlaZHnW", + "bKE", + "HZRbwxOiQx+diNopi6Nu", + "fwyasXgYL3rP314331b", + "LWxXAiSW4", + "UlWIjv1HGrC6Ngmt4Nohx", + "FOa+Lc0bxTDpTwIh2", + "0+RY", + "xmRVMqokHHpvsiH0", + }, + DeviceID: utils.GetMD5EncodeStr(x.Username + x.Password), + ClientID: "ZUBzD9J_XPXfn7f7", + ClientSecret: "yESVmHecEe6F0aou69vl-g", + ClientVersion: "1.0.7.1938", + PackageName: "com.xunlei.browser", + UserAgent: "ANDROID-com.xunlei.browser/1.0.7.1938 netWorkType/5G appid/22062 deviceName/Xiaomi_M2004j7ac deviceModel/M2004J7AC OSVersion/12 protocolVersion/301 platformVersion/10 sdkVersion/233100 Oauth2Client/0.9 (Linux 4_14_186-perf-gddfs8vbb238b) (JAVA 0)", + DownloadUserAgent: "AndroidDownloadManager/12 (Linux; U; Android 12; M2004J7AC Build/SP1A.210812.016)", + UseVideoUrl: x.UseVideoUrl, + + refreshCTokenCk: func(token string) { + x.CaptchaToken = token + op.MustSaveDriverStorage(x) + }, + }, + refreshTokenFunc: func() error { + // 通过RefreshToken刷新 + token, err := x.RefreshToken(x.TokenResp.RefreshToken) + if err != nil { + // 重新登录 + token, err = x.Login(x.Username, x.Password) + if err != nil { + x.GetStorage().SetStatus(fmt.Sprintf("%+v", err.Error())) + op.MustSaveDriverStorage(x) + } + } + x.SetTokenResp(token) + return err + }, + } + } + + // 自定义验证码token + ctoekn := strings.TrimSpace(x.CaptchaToken) + if ctoekn != "" { + x.SetCaptchaToken(ctoekn) + } + x.XunLeiBrowserCommon.UseVideoUrl = x.UseVideoUrl + x.Addition.RootFolderID = x.RootFolderID + // 防止重复登录 + identity := x.GetIdentity() + if x.identity != identity || !x.IsLogin() { + x.identity = identity + // 登录 + token, err := x.Login(x.Username, x.Password) + if err != nil { + return err + } + x.SetTokenResp(token) + } + + // 获取 spaceToken + err = spaceTokenFunc() + if err != nil { + return err + } + + return nil +} + +func (x *ThunderBrowser) Drop(ctx context.Context) error { + return nil +} + +type ThunderBrowserExpert struct { + *XunLeiBrowserCommon + model.Storage + ExpertAddition + + identity string +} + +func (x *ThunderBrowserExpert) Config() driver.Config { + return configExpert +} + +func (x *ThunderBrowserExpert) GetAddition() driver.Additional { + return &x.ExpertAddition +} + +func (x *ThunderBrowserExpert) Init(ctx context.Context) (err error) { + + spaceTokenFunc := func() error { + // 如果用户未设置 "超级保险柜" 密码 则直接返回 + if x.SafePassword == "" { + return nil + } + // 通过 GetSafeAccessToken 获取 + token, err := x.GetSafeAccessToken(x.SafePassword) + x.SetSpaceTokenResp(token) + return err + } + + // 防止重复登录 + identity := x.GetIdentity() + if identity != x.identity || !x.IsLogin() { + x.identity = identity + x.XunLeiBrowserCommon = &XunLeiBrowserCommon{ + Common: &Common{ + client: base.NewRestyClient(), + + DeviceID: func() string { + if len(x.DeviceID) != 32 { + return utils.GetMD5EncodeStr(x.DeviceID) + } + return x.DeviceID + }(), + ClientID: x.ClientID, + ClientSecret: x.ClientSecret, + ClientVersion: x.ClientVersion, + PackageName: x.PackageName, + UserAgent: x.UserAgent, + DownloadUserAgent: x.DownloadUserAgent, + UseVideoUrl: x.UseVideoUrl, + + refreshCTokenCk: func(token string) { + x.CaptchaToken = token + 
op.MustSaveDriverStorage(x) + }, + }, + } + + if x.CaptchaToken != "" { + x.SetCaptchaToken(x.CaptchaToken) + } + x.XunLeiBrowserCommon.UseVideoUrl = x.UseVideoUrl + x.ExpertAddition.RootFolderID = x.RootFolderID + // 签名方法 + if x.SignType == "captcha_sign" { + x.Common.Timestamp = x.Timestamp + x.Common.CaptchaSign = x.CaptchaSign + } else { + x.Common.Algorithms = strings.Split(x.Algorithms, ",") + } + + // 登录方式 + if x.LoginType == "refresh_token" { + // 通过RefreshToken登录 + token, err := x.XunLeiBrowserCommon.RefreshToken(x.ExpertAddition.RefreshToken) + if err != nil { + return err + } + x.SetTokenResp(token) + + // 刷新token方法 + x.SetRefreshTokenFunc(func() error { + token, err := x.XunLeiBrowserCommon.RefreshToken(x.TokenResp.RefreshToken) + if err != nil { + x.GetStorage().SetStatus(fmt.Sprintf("%+v", err.Error())) + } + x.SetTokenResp(token) + op.MustSaveDriverStorage(x) + return err + }) + + err = spaceTokenFunc() + if err != nil { + return err + } + + } else { + // 通过用户密码登录 + token, err := x.Login(x.Username, x.Password) + if err != nil { + return err + } + x.SetTokenResp(token) + x.SetRefreshTokenFunc(func() error { + token, err := x.XunLeiBrowserCommon.RefreshToken(x.TokenResp.RefreshToken) + if err != nil { + token, err = x.Login(x.Username, x.Password) + if err != nil { + x.GetStorage().SetStatus(fmt.Sprintf("%+v", err.Error())) + } + } + x.SetTokenResp(token) + op.MustSaveDriverStorage(x) + return err + }) + + err = spaceTokenFunc() + if err != nil { + return err + } + } + } else { + // 仅修改验证码token + if x.CaptchaToken != "" { + x.SetCaptchaToken(x.CaptchaToken) + } + + err = spaceTokenFunc() + if err != nil { + return err + } + + x.XunLeiBrowserCommon.UserAgent = x.UserAgent + x.XunLeiBrowserCommon.DownloadUserAgent = x.DownloadUserAgent + x.XunLeiBrowserCommon.UseVideoUrl = x.UseVideoUrl + x.ExpertAddition.RootFolderID = x.RootFolderID + } + + return nil +} + +func (x *ThunderBrowserExpert) Drop(ctx context.Context) error { + return nil +} + +func (x *ThunderBrowserExpert) SetTokenResp(token *TokenResp) { + x.XunLeiBrowserCommon.SetTokenResp(token) + if token != nil { + x.ExpertAddition.RefreshToken = token.RefreshToken + } +} + +type XunLeiBrowserCommon struct { + *Common + *TokenResp // 登录信息 + + refreshTokenFunc func() error +} + +func (xc *XunLeiBrowserCommon) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) { + return xc.getFiles(ctx, dir.GetID(), args.ReqPath) +} + +func (xc *XunLeiBrowserCommon) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) { + var lFile Files + + params := map[string]string{ + "_magic": "2021", + "space": "SPACE_BROWSER", + "thumbnail_size": "SIZE_LARGE", + "with": "url", + } + // 对 "迅雷云盘" 内的文件 特殊处理 + if file.GetPath() == ThunderDriveFileID { + params = map[string]string{} + } else if file.GetPath() == ThunderBrowserDriveSafeFileID { + // 对 "超级保险箱" 内的文件 特殊处理 + params["space"] = "SPACE_BROWSER_SAFE" + } + + _, err := xc.Request(FILE_API_URL+"/{fileID}", http.MethodGet, func(r *resty.Request) { + r.SetContext(ctx) + r.SetPathParam("fileID", file.GetID()) + r.SetQueryParams(params) + //r.SetQueryParam("space", "") + }, &lFile) + if err != nil { + return nil, err + } + link := &model.Link{ + URL: lFile.WebContentLink, + Header: http.Header{ + "User-Agent": {xc.DownloadUserAgent}, + }, + } + + if xc.UseVideoUrl { + for _, media := range lFile.Medias { + if media.Link.URL != "" { + link.URL = media.Link.URL + break + } + } + } + return link, nil +} + +func (xc *XunLeiBrowserCommon) 
MakeDir(ctx context.Context, parentDir model.Obj, dirName string) error { + js := base.Json{ + "kind": FOLDER, + "name": dirName, + "parent_id": parentDir.GetID(), + "space": "SPACE_BROWSER", + } + if parentDir.GetPath() == ThunderDriveFileID { + js = base.Json{ + "kind": FOLDER, + "name": dirName, + "parent_id": parentDir.GetID(), + } + } else if parentDir.GetPath() == ThunderBrowserDriveSafeFileID { + js = base.Json{ + "kind": FOLDER, + "name": dirName, + "parent_id": parentDir.GetID(), + "space": "SPACE_BROWSER_SAFE", + } + } + _, err := xc.Request(FILE_API_URL, http.MethodPost, func(r *resty.Request) { + r.SetContext(ctx) + r.SetBody(&js) + }, nil) + return err +} + +func (xc *XunLeiBrowserCommon) Move(ctx context.Context, srcObj, dstDir model.Obj) error { + + srcSpace := "SPACE_BROWSER" + dstSpace := "SPACE_BROWSER" + + // 对 "超级保险箱" 内的文件 特殊处理 + if srcObj.GetPath() == ThunderBrowserDriveSafeFileID { + srcSpace = "SPACE_BROWSER_SAFE" + } + if dstDir.GetPath() == ThunderBrowserDriveSafeFileID { + dstSpace = "SPACE_BROWSER_SAFE" + } + + params := map[string]string{ + "_from": dstSpace, + } + js := base.Json{ + "to": base.Json{"parent_id": dstDir.GetID(), "space": dstSpace}, + "space": srcSpace, + "ids": []string{srcObj.GetID()}, + } + // 对 "迅雷云盘" 内的文件 特殊处理 + if srcObj.GetPath() == ThunderDriveFileID { + params = map[string]string{} + js = base.Json{ + "to": base.Json{"parent_id": dstDir.GetID()}, + "ids": []string{srcObj.GetID()}, + } + } + + _, err := xc.Request(FILE_API_URL+":batchMove", http.MethodPost, func(r *resty.Request) { + r.SetContext(ctx) + r.SetBody(&js) + r.SetQueryParams(params) + }, nil) + return err +} + +func (xc *XunLeiBrowserCommon) Rename(ctx context.Context, srcObj model.Obj, newName string) error { + + params := map[string]string{ + "space": "SPACE_BROWSER", + } + // 对 "迅雷云盘" 内的文件 特殊处理 + if srcObj.GetPath() == ThunderDriveFileID { + params = map[string]string{} + } else if srcObj.GetPath() == ThunderBrowserDriveSafeFileID { + // 对 "超级保险箱" 内的文件 特殊处理 + params = map[string]string{ + "space": "SPACE_BROWSER_SAFE", + } + } + + _, err := xc.Request(FILE_API_URL+"/{fileID}", http.MethodPatch, func(r *resty.Request) { + r.SetContext(ctx) + r.SetPathParam("fileID", srcObj.GetID()) + r.SetBody(&base.Json{"name": newName}) + r.SetQueryParams(params) + }, nil) + return err +} + +func (xc *XunLeiBrowserCommon) Offline(ctx context.Context, args model.OtherArgs) (interface{}, error) { + _, err := xc.Request(FILE_API_URL, http.MethodPost, func(r *resty.Request) { + r.SetContext(ctx) + r.SetHeaders(map[string]string{ + "X-Device-Id": xc.DeviceID, + "User-Agent": xc.UserAgent, + "Peer-Id": xc.DeviceID, + "client_id": xc.ClientID, + "x-client-id": xc.ClientID, + "X-Guid": xc.DeviceID, + }) + r.SetBody(&base.Json{ + "kind": "drive#file", + "name": "", + "parent_id": args.Obj.GetID(), + "upload_type": "UPLOAD_TYPE_URL", + "url": &base.Json{ + "url": args.Data, + "params": "{}", + "parent_id": args.Obj.GetID(), + }, + }) + }, nil) + if err != nil { + return nil, err + } + return "ok", nil +} + +func (xc *XunLeiBrowserCommon) Copy(ctx context.Context, srcObj, dstDir model.Obj) error { + + srcSpace := "SPACE_BROWSER" + dstSpace := "SPACE_BROWSER" + + // 对 "超级保险箱" 内的文件 特殊处理 + if srcObj.GetPath() == ThunderBrowserDriveSafeFileID { + srcSpace = "SPACE_BROWSER_SAFE" + } + if dstDir.GetPath() == ThunderBrowserDriveSafeFileID { + dstSpace = "SPACE_BROWSER_SAFE" + } + + params := map[string]string{ + "_from": dstSpace, + } + js := base.Json{ + "to": base.Json{"parent_id": dstDir.GetID(), "space": 
dstSpace}, + "space": srcSpace, + "ids": []string{srcObj.GetID()}, + } + // 对 "迅雷云盘" 内的文件 特殊处理 + if srcObj.GetPath() == ThunderDriveFileID { + params = map[string]string{} + js = base.Json{ + "to": base.Json{"parent_id": dstDir.GetID()}, + "ids": []string{srcObj.GetID()}, + } + } + + _, err := xc.Request(FILE_API_URL+":batchCopy", http.MethodPost, func(r *resty.Request) { + r.SetContext(ctx) + r.SetBody(&js) + r.SetQueryParams(params) + }, nil) + return err +} + +func (xc *XunLeiBrowserCommon) Remove(ctx context.Context, obj model.Obj) error { + + js := base.Json{ + "ids": []string{obj.GetID()}, + "space": "SPACE_BROWSER", + } + // 对 "迅雷云盘" 内的文件 特殊处理 + if obj.GetPath() == ThunderDriveFileID { + js = base.Json{ + "ids": []string{obj.GetID()}, + } + } else if obj.GetPath() == ThunderBrowserDriveSafeFileID { + // 对 "超级保险箱" 内的文件 特殊处理 + js = base.Json{ + "ids": []string{obj.GetID()}, + "space": "SPACE_BROWSER_SAFE", + } + } + + if xc.RemoveWay == "delete" && obj.GetPath() == ThunderDriveFileID { + _, err := xc.Request(FILE_API_URL+"/{fileID}/trash", http.MethodPatch, func(r *resty.Request) { + r.SetContext(ctx) + r.SetPathParam("fileID", obj.GetID()) + r.SetBody("{}") + }, nil) + return err + } else if obj.GetPath() == ThunderBrowserDriveSafeFileID { + _, err := xc.Request(FILE_API_URL+":batchDelete", http.MethodPost, func(r *resty.Request) { + r.SetContext(ctx) + r.SetBody(&js) + }, nil) + return err + } + + _, err := xc.Request(FILE_API_URL+":batchTrash", http.MethodPost, func(r *resty.Request) { + r.SetContext(ctx) + r.SetBody(&js) + }, nil) + return err + +} + +func (xc *XunLeiBrowserCommon) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error { + hi := stream.GetHash() + gcid := hi.GetHash(hash_extend.GCID) + if len(gcid) < hash_extend.GCID.Width { + tFile, err := stream.CacheFullInTempFile() + if err != nil { + return err + } + + gcid, err = utils.HashFile(hash_extend.GCID, tFile, stream.GetSize()) + if err != nil { + return err + } + } + + js := base.Json{ + "kind": FILE, + "parent_id": dstDir.GetID(), + "name": stream.GetName(), + "size": stream.GetSize(), + "hash": gcid, + "upload_type": UPLOAD_TYPE_RESUMABLE, + "space": "SPACE_BROWSER", + } + // 对 "迅雷云盘" 内的文件 特殊处理 + if dstDir.GetPath() == ThunderDriveFileID { + js = base.Json{ + "kind": FILE, + "parent_id": dstDir.GetID(), + "name": stream.GetName(), + "size": stream.GetSize(), + "hash": gcid, + "upload_type": UPLOAD_TYPE_RESUMABLE, + } + } else if dstDir.GetPath() == ThunderBrowserDriveSafeFileID { + // 对 "超级保险箱" 内的文件 特殊处理 + js = base.Json{ + "kind": FILE, + "parent_id": dstDir.GetID(), + "name": stream.GetName(), + "size": stream.GetSize(), + "hash": gcid, + "upload_type": UPLOAD_TYPE_RESUMABLE, + "space": "SPACE_BROWSER_SAFE", + } + } + + var resp UploadTaskResponse + _, err := xc.Request(FILE_API_URL, http.MethodPost, func(r *resty.Request) { + r.SetContext(ctx) + r.SetBody(&js) + }, &resp) + if err != nil { + return err + } + + param := resp.Resumable.Params + if resp.UploadType == UPLOAD_TYPE_RESUMABLE { + param.Endpoint = strings.TrimLeft(param.Endpoint, param.Bucket+".") + s, err := session.NewSession(&aws.Config{ + Credentials: credentials.NewStaticCredentials(param.AccessKeyID, param.AccessKeySecret, param.SecurityToken), + Region: aws.String("xunlei"), + Endpoint: aws.String(param.Endpoint), + }) + if err != nil { + return err + } + uploader := s3manager.NewUploader(s) + if stream.GetSize() > s3manager.MaxUploadParts*s3manager.DefaultUploadPartSize { + uploader.PartSize = 
stream.GetSize() / (s3manager.MaxUploadParts - 1) + } + _, err = uploader.UploadWithContext(ctx, &s3manager.UploadInput{ + Bucket: aws.String(param.Bucket), + Key: aws.String(param.Key), + Expires: aws.Time(param.Expiration), + Body: stream, + }) + return err + } + return nil +} + +func (xc *XunLeiBrowserCommon) getFiles(ctx context.Context, folderId string, path string) ([]model.Obj, error) { + files := make([]model.Obj, 0) + var pageToken string + for { + var fileList FileList + folderSpace := "SPACE_BROWSER" + params := map[string]string{ + "parent_id": folderId, + "page_token": pageToken, + "space": folderSpace, + "filters": `{"trashed":{"eq":false}}`, + "with_audit": "true", + "thumbnail_size": "SIZE_LARGE", + } + var fileType int8 + // 处理特殊目录 “迅雷云盘” 设置特殊的 params 以便正常访问 + pattern1 := fmt.Sprintf(`^/.*/%s(/.*)?$`, ThunderDriveFolderName) + thunderDriveMatch, _ := regexp.MatchString(pattern1, path) + // 处理特殊目录 “超级保险箱” 设置特殊的 params 以便正常访问 + pattern2 := fmt.Sprintf(`^/.*/%s(/.*)?$`, ThunderBrowserDriveSafeFolderName) + thunderBrowserDriveSafeMatch, _ := regexp.MatchString(pattern2, path) + + // 如果是 "迅雷云盘" 内的 + if folderId == ThunderDriveFileID || thunderDriveMatch { + params = map[string]string{ + "space": "", + "__type": "drive", + "refresh": "true", + "__sync": "true", + "parent_id": folderId, + "page_token": pageToken, + "with_audit": "true", + "limit": "100", + "filters": `{"phase":{"eq":"PHASE_TYPE_COMPLETE"},"trashed":{"eq":false}}`, + } + // 如果不是 "迅雷云盘"的"根目录" + if folderId == ThunderDriveFileID { + params["parent_id"] = "" + } + fileType = ThunderDriveType + } else if thunderBrowserDriveSafeMatch { + // 如果是 "超级保险箱" 内的 + fileType = ThunderBrowserDriveSafeType + params["space"] = "SPACE_BROWSER_SAFE" + } + + _, err := xc.Request(FILE_API_URL, http.MethodGet, func(r *resty.Request) { + r.SetContext(ctx) + r.SetQueryParams(params) + }, &fileList) + if err != nil { + return nil, err + } + // 对文件夹也进行处理 + fileList.FolderType = fileType + + for i := 0; i < len(fileList.Files); i++ { + file := &fileList.Files[i] + // 标记 文件夹内的文件 + file.FileType = fileList.FolderType + // 解决 "迅雷云盘" 重复出现问题————迅雷后端发送错误 + if file.Name == ThunderDriveFolderName && file.ID == "" && file.FolderType == ThunderDriveFolderType && folderId != "" { + continue + } + // 处理特殊目录 “迅雷云盘” 设置特殊的文件夹ID + if file.Name == ThunderDriveFolderName && file.ID == "" && file.FolderType == ThunderDriveFolderType { + file.ID = ThunderDriveFileID + } else if file.Name == ThunderBrowserDriveSafeFolderName && file.FolderType == ThunderBrowserDriveSafeFolderType { + file.FileType = ThunderBrowserDriveSafeType + } + files = append(files, file) + } + + if fileList.NextPageToken == "" { + break + } + pageToken = fileList.NextPageToken + } + return files, nil +} + +// SetRefreshTokenFunc 设置刷新Token的方法 +func (xc *XunLeiBrowserCommon) SetRefreshTokenFunc(fn func() error) { + xc.refreshTokenFunc = fn +} + +// SetTokenResp 设置Token +func (xc *XunLeiBrowserCommon) SetTokenResp(tr *TokenResp) { + xc.TokenResp = tr +} + +// SetSpaceTokenResp 设置Token +func (xc *XunLeiBrowserCommon) SetSpaceTokenResp(spaceToken string) { + xc.TokenResp.Token = spaceToken +} + +// Request 携带Authorization和CaptchaToken的请求 +func (xc *XunLeiBrowserCommon) Request(url string, method string, callback base.ReqCallback, resp interface{}) ([]byte, error) { + data, err := xc.Common.Request(url, method, func(req *resty.Request) { + req.SetHeaders(map[string]string{ + "Authorization": xc.GetToken(), + "X-Captcha-Token": xc.GetCaptchaToken(), + "X-Space-Authorization": xc.GetSpaceToken(), 
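One observation on the `Request` wrapper this header map belongs to: a nil error never type-asserts to `*ErrResp`, so after a successful `Common.Request` the `if !ok` guard below returns `(nil, nil)` and the raw body bytes are dropped on the happy path. Callers still work because `resp` is deserialized through `req.SetResult` inside `Common.Request`, but if that is intentional it deserves a comment:

```go
// Demonstration, not patch code: why the happy path yields nil bytes.
var err error           // nil after a successful request
_, ok := err.(*ErrResp) // ok == false for a nil interface value
// -> Request takes `if !ok { return nil, err }` and returns (nil, nil)
```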
+ }) + if callback != nil { + callback(req) + } + }, resp) + + errResp, ok := err.(*ErrResp) + if !ok { + return nil, err + } + + switch errResp.ErrorCode { + case 0: + return data, nil + case 4122, 4121, 10, 16: + if xc.refreshTokenFunc != nil { + if err = xc.refreshTokenFunc(); err == nil { + break + } + } + return nil, err + case 9: + // space_token 获取失败 + if errResp.ErrorMsg == "space_token_invalid" { + if token, err := xc.GetSafeAccessToken(xc.Token); err != nil { + return nil, err + } else { + xc.SetSpaceTokenResp(token) + } + + } + if errResp.ErrorMsg == "captcha_invalid" { + // 验证码token过期 + if err = xc.RefreshCaptchaTokenAtLogin(GetAction(method, url), xc.UserID); err != nil { + return nil, err + } + } + return nil, err + default: + return nil, err + } + return xc.Request(url, method, callback, resp) +} + +// RefreshToken 刷新Token +func (xc *XunLeiBrowserCommon) RefreshToken(refreshToken string) (*TokenResp, error) { + var resp TokenResp + _, err := xc.Common.Request(XLUSER_API_URL+"/auth/token", http.MethodPost, func(req *resty.Request) { + req.SetBody(&base.Json{ + "grant_type": "refresh_token", + "refresh_token": refreshToken, + "client_id": xc.ClientID, + "client_secret": xc.ClientSecret, + }) + }, &resp) + if err != nil { + return nil, err + } + + if resp.RefreshToken == "" { + return nil, errs.EmptyToken + } + return &resp, nil +} + +// GetSafeAccessToken 获取 超级保险柜 AccessToken +func (xc *XunLeiBrowserCommon) GetSafeAccessToken(safePassword string) (string, error) { + var resp TokenResp + _, err := xc.Request(XLUSER_API_URL+"/password/check", http.MethodPost, func(req *resty.Request) { + req.SetBody(&base.Json{ + "scene": "box", + "password": EncryptPassword(safePassword), + }) + }, &resp) + if err != nil { + return "", err + } + + if resp.Token == "" { + return "", errs.EmptyToken + } + return resp.Token, nil +} + +// Login 登录 +func (xc *XunLeiBrowserCommon) Login(username, password string) (*TokenResp, error) { + url := XLUSER_API_URL + "/auth/signin" + err := xc.RefreshCaptchaTokenInLogin(GetAction(http.MethodPost, url), username) + if err != nil { + return nil, err + } + + var resp TokenResp + _, err = xc.Common.Request(url, http.MethodPost, func(req *resty.Request) { + req.SetBody(&SignInRequest{ + CaptchaToken: xc.GetCaptchaToken(), + ClientID: xc.ClientID, + ClientSecret: xc.ClientSecret, + Username: username, + Password: password, + }) + }, &resp) + if err != nil { + return nil, err + } + return &resp, nil +} + +func (xc *XunLeiBrowserCommon) IsLogin() bool { + if xc.TokenResp == nil { + return false + } + _, err := xc.Request(XLUSER_API_URL+"/user/me", http.MethodGet, nil, nil) + return err == nil +} diff --git a/drivers/thunder_browser/meta.go b/drivers/thunder_browser/meta.go new file mode 100644 index 0000000000000000000000000000000000000000..9d16cd78e5ca81eaa05f7b5a8a0672bb757e4c0e --- /dev/null +++ b/drivers/thunder_browser/meta.go @@ -0,0 +1,108 @@ +package thunder_browser + +import ( + "crypto/md5" + "encoding/hex" + + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/op" + "github.com/alist-org/alist/v3/pkg/utils" +) + +// ExpertAddition 高级设置 +type ExpertAddition struct { + driver.RootID + + LoginType string `json:"login_type" type:"select" options:"user,refresh_token" default:"user"` + SignType string `json:"sign_type" type:"select" options:"algorithms,captcha_sign" default:"algorithms"` + + // 登录方式1 + Username string `json:"username" required:"true" help:"login type is user,this is required"` + Password string 
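For context, the safe-box ("超级保险箱", the encrypted vault) unlock flow assembled from `GetSafeAccessToken` and `SetSpaceTokenResp` above, sketched with error handling trimmed. Note that the recovery branch in `Request` passes `xc.Token` back in as the password argument; that only round-trips correctly if the caller keeps the vault password there, which is worth double-checking:

```go
// Hypothetical usage sketch; safePassword is the user-supplied vault password.
token, err := xc.GetSafeAccessToken(safePassword) // POSTs md5(password) to /password/check
if err != nil {
	return err
}
xc.SetSpaceTokenResp(token) // sent on every call as X-Space-Authorization
```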
`json:"password" required:"true" help:"login type is user,this is required"` + SafePassword string `json:"safe_password" required:"false" help:"login type is user,this is required"` // 超级保险箱密码 + // 登录方式2 + RefreshToken string `json:"refresh_token" required:"true" help:"login type is refresh_token,this is required"` + + // 签名方法1 + Algorithms string `json:"algorithms" required:"true" help:"sign type is algorithms,this is required" default:"x+I5XiTByg,6QU1x5DqGAV3JKg6h,VI1vL1WXr7st0es,n+/3yhlrnKs4ewhLgZhZ5ITpt554,UOip2PE7BLIEov/ZX6VOnsz,Q70h9lpViNCOC8sGVkar9o22LhBTjfP,IVHFuB1JcMlaZHnW,bKE,HZRbwxOiQx+diNopi6Nu,fwyasXgYL3rP314331b,LWxXAiSW4,UlWIjv1HGrC6Ngmt4Nohx,FOa+Lc0bxTDpTwIh2,0+RY,xmRVMqokHHpvsiH0"` + // 签名方法2 + CaptchaSign string `json:"captcha_sign" required:"true" help:"sign type is captcha_sign,this is required"` + Timestamp string `json:"timestamp" required:"true" help:"sign type is captcha_sign,this is required"` + + // 验证码 + CaptchaToken string `json:"captcha_token"` + + // 必要且影响登录,由签名决定 + DeviceID string `json:"device_id" required:"true" default:"9aa5c268e7bcfc197a9ad88e2fb330e5"` + ClientID string `json:"client_id" required:"true" default:"ZUBzD9J_XPXfn7f7"` + ClientSecret string `json:"client_secret" required:"true" default:"yESVmHecEe6F0aou69vl-g"` + ClientVersion string `json:"client_version" required:"true" default:"1.0.7.1938"` + PackageName string `json:"package_name" required:"true" default:"com.xunlei.browser"` + + // 不影响登录,影响下载速度 + UserAgent string `json:"user_agent" required:"true" default:"ANDROID-com.xunlei.browser/1.0.7.1938 netWorkType/5G appid/22062 deviceName/Xiaomi_M2004j7ac deviceModel/M2004J7AC OSVersion/12 protocolVersion/301 platformVersion/10 sdkVersion/233100 Oauth2Client/0.9 (Linux 4_14_186-perf-gddfs8vbb238b) (JAVA 0)"` + DownloadUserAgent string `json:"download_user_agent" required:"true" default:"AndroidDownloadManager/12 (Linux; U; Android 12; M2004J7AC Build/SP1A.210812.016)"` + + // 优先使用视频链接代替下载链接 + UseVideoUrl bool `json:"use_video_url"` + // 移除方式 + RemoveWay string `json:"remove_way" required:"true" type:"select" options:"trash,delete"` +} + +// GetIdentity 登录特征,用于判断是否重新登录 +func (i *ExpertAddition) GetIdentity() string { + hash := md5.New() + if i.LoginType == "refresh_token" { + hash.Write([]byte(i.RefreshToken)) + } else { + hash.Write([]byte(i.Username + i.Password)) + } + + if i.SignType == "captcha_sign" { + hash.Write([]byte(i.CaptchaSign + i.Timestamp)) + } else { + hash.Write([]byte(i.Algorithms)) + } + + hash.Write([]byte(i.DeviceID)) + hash.Write([]byte(i.ClientID)) + hash.Write([]byte(i.ClientSecret)) + hash.Write([]byte(i.ClientVersion)) + hash.Write([]byte(i.PackageName)) + return hex.EncodeToString(hash.Sum(nil)) +} + +type Addition struct { + driver.RootID + Username string `json:"username" required:"true"` + Password string `json:"password" required:"true"` + SafePassword string `json:"safe_password" required:"false"` // 超级保险箱密码 + CaptchaToken string `json:"captcha_token"` + UseVideoUrl bool `json:"use_video_url" default:"false"` + RemoveWay string `json:"remove_way" required:"true" type:"select" options:"trash,delete"` +} + +// GetIdentity 登录特征,用于判断是否重新登录 +func (i *Addition) GetIdentity() string { + return utils.GetMD5EncodeStr(i.Username + i.Password) +} + +var config = driver.Config{ + Name: "ThunderBrowser", + LocalSort: true, + OnlyProxy: true, +} + +var configExpert = driver.Config{ + Name: "ThunderBrowserExpert", + LocalSort: true, +} + +func init() { + op.RegisterDriver(func() driver.Driver { + return &ThunderBrowser{} + }) + 
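`GetIdentity` above feeds the drivers' re-login guard: `Init` recomputes the hash and only performs a full login when credentials or client parameters changed. The thunderx `Init` later in this diff shows the exact pattern; presumably the ThunderBrowser `Init` (outside this hunk) does the same:

```go
// Re-login guard pattern, as used by the sibling thunderx driver in this diff:
identity := x.GetIdentity()
if x.identity != identity || !x.IsLogin() {
	x.identity = identity
	// full username/password login happens only when the identity hash
	// changed or the cached token no longer authenticates
}
```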
op.RegisterDriver(func() driver.Driver { + return &ThunderBrowserExpert{} + }) +} diff --git a/drivers/thunder_browser/types.go b/drivers/thunder_browser/types.go new file mode 100644 index 0000000000000000000000000000000000000000..774b34bb287cfe32419acb144f54688d7b0026ea --- /dev/null +++ b/drivers/thunder_browser/types.go @@ -0,0 +1,223 @@ +package thunder_browser + +import ( + "fmt" + "strconv" + "time" + + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/pkg/utils" + hash_extend "github.com/alist-org/alist/v3/pkg/utils/hash" +) + +type ErrResp struct { + ErrorCode int64 `json:"error_code"` + ErrorMsg string `json:"error"` + ErrorDescription string `json:"error_description"` + // ErrorDetails interface{} `json:"error_details"` +} + +func (e *ErrResp) IsError() bool { + return e.ErrorCode != 0 || e.ErrorMsg != "" || e.ErrorDescription != "" +} + +func (e *ErrResp) Error() string { + return fmt.Sprintf("ErrorCode: %d ,Error: %s ,ErrorDescription: %s ", e.ErrorCode, e.ErrorMsg, e.ErrorDescription) +} + +/* +* 验证码Token +**/ +type CaptchaTokenRequest struct { + Action string `json:"action"` + CaptchaToken string `json:"captcha_token"` + ClientID string `json:"client_id"` + DeviceID string `json:"device_id"` + Meta map[string]string `json:"meta"` + RedirectUri string `json:"redirect_uri"` +} + +type CaptchaTokenResponse struct { + CaptchaToken string `json:"captcha_token"` + ExpiresIn int64 `json:"expires_in"` + Url string `json:"url"` +} + +/* +* 登录 +**/ +type TokenResp struct { + TokenType string `json:"token_type"` + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + ExpiresIn int64 `json:"expires_in"` + + Sub string `json:"sub"` + UserID string `json:"user_id"` + + Token string `json:"token"` // "超级保险箱" 访问Token +} + +func (t *TokenResp) GetToken() string { + return fmt.Sprint(t.TokenType, " ", t.AccessToken) +} + +// GetSpaceToken 获取"超级保险箱" 访问Token +func (t *TokenResp) GetSpaceToken() string { + return t.Token +} + +type SignInRequest struct { + CaptchaToken string `json:"captcha_token"` + + ClientID string `json:"client_id"` + ClientSecret string `json:"client_secret"` + + Username string `json:"username"` + Password string `json:"password"` +} + +/* +* 文件 +**/ +type FileList struct { + Kind string `json:"kind"` + NextPageToken string `json:"next_page_token"` + Files []Files `json:"files"` + Version string `json:"version"` + VersionOutdated bool `json:"version_outdated"` + FolderType int8 +} + +type Link struct { + URL string `json:"url"` + Token string `json:"token"` + Expire time.Time `json:"expire"` + Type string `json:"type"` +} + +var _ model.Obj = (*Files)(nil) + +type Files struct { + Kind string `json:"kind"` + ID string `json:"id"` + ParentID string `json:"parent_id"` + Name string `json:"name"` + //UserID string `json:"user_id"` + Size string `json:"size"` + //Revision string `json:"revision"` + //FileExtension string `json:"file_extension"` + //MimeType string `json:"mime_type"` + //Starred bool `json:"starred"` + WebContentLink string `json:"web_content_link"` + CreatedTime CustomTime `json:"created_time"` + ModifiedTime CustomTime `json:"modified_time"` + IconLink string `json:"icon_link"` + ThumbnailLink string `json:"thumbnail_link"` + // Md5Checksum string `json:"md5_checksum"` + Hash string `json:"hash"` + // Links map[string]Link `json:"links"` + // Phase string `json:"phase"` + // Audit struct { + // Status string `json:"status"` + // Message string `json:"message"` + // Title string `json:"title"` + 
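A small caveat about the `Files` struct continued below: the API delivers `size` as a JSON string, and `GetSize` discards the `strconv.ParseInt` error, so a missing or malformed size silently becomes 0 (relevant when `LocalSort` orders by size). Demonstration:

```go
// ParseInt's error is dropped in GetSize; the zero value masks bad input:
n, _ := strconv.ParseInt("", 10, 64)   // n == 0 (error discarded)
m, _ := strconv.ParseInt("42", 10, 64) // m == 42
```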
// } `json:"audit"` + Medias []struct { + //Category string `json:"category"` + //IconLink string `json:"icon_link"` + //IsDefault bool `json:"is_default"` + //IsOrigin bool `json:"is_origin"` + //IsVisible bool `json:"is_visible"` + Link Link `json:"link"` + //MediaID string `json:"media_id"` + //MediaName string `json:"media_name"` + //NeedMoreQuota bool `json:"need_more_quota"` + //Priority int `json:"priority"` + //RedirectLink string `json:"redirect_link"` + //ResolutionName string `json:"resolution_name"` + // Video struct { + // AudioCodec string `json:"audio_codec"` + // BitRate int `json:"bit_rate"` + // Duration int `json:"duration"` + // FrameRate int `json:"frame_rate"` + // Height int `json:"height"` + // VideoCodec string `json:"video_codec"` + // VideoType string `json:"video_type"` + // Width int `json:"width"` + // } `json:"video"` + // VipTypes []string `json:"vip_types"` + } `json:"medias"` + Trashed bool `json:"trashed"` + DeleteTime string `json:"delete_time"` + OriginalURL string `json:"original_url"` + //Params struct{} `json:"params"` + //OriginalFileIndex int `json:"original_file_index"` + //Space string `json:"space"` + //Apps []interface{} `json:"apps"` + //Writable bool `json:"writable"` + FolderType string `json:"folder_type"` + //Collection interface{} `json:"collection"` + FileType int8 +} + +func (c *Files) GetHash() utils.HashInfo { + return utils.NewHashInfo(hash_extend.GCID, c.Hash) +} + +func (c *Files) GetSize() int64 { size, _ := strconv.ParseInt(c.Size, 10, 64); return size } +func (c *Files) GetName() string { return c.Name } +func (c *Files) CreateTime() time.Time { return c.CreatedTime.Time } +func (c *Files) ModTime() time.Time { return c.ModifiedTime.Time } +func (c *Files) IsDir() bool { return c.Kind == FOLDER } +func (c *Files) GetID() string { return c.ID } +func (c *Files) GetPath() string { + // 对特殊文件进行特殊处理 + if c.FileType == ThunderDriveType { + return ThunderDriveFileID + } else if c.FileType == ThunderBrowserDriveSafeType { + return ThunderBrowserDriveSafeFileID + } + return "" +} +func (c *Files) Thumb() string { return c.ThumbnailLink } + +/* +* 上传 +**/ +type UploadTaskResponse struct { + UploadType string `json:"upload_type"` + + /*//UPLOAD_TYPE_FORM + Form struct { + //Headers struct{} `json:"headers"` + Kind string `json:"kind"` + Method string `json:"method"` + MultiParts struct { + OSSAccessKeyID string `json:"OSSAccessKeyId"` + Signature string `json:"Signature"` + Callback string `json:"callback"` + Key string `json:"key"` + Policy string `json:"policy"` + XUserData string `json:"x:user_data"` + } `json:"multi_parts"` + URL string `json:"url"` + } `json:"form"`*/ + + //UPLOAD_TYPE_RESUMABLE + Resumable struct { + Kind string `json:"kind"` + Params struct { + AccessKeyID string `json:"access_key_id"` + AccessKeySecret string `json:"access_key_secret"` + Bucket string `json:"bucket"` + Endpoint string `json:"endpoint"` + Expiration time.Time `json:"expiration"` + Key string `json:"key"` + SecurityToken string `json:"security_token"` + } `json:"params"` + Provider string `json:"provider"` + } `json:"resumable"` + + File Files `json:"file"` +} diff --git a/drivers/thunder_browser/util.go b/drivers/thunder_browser/util.go new file mode 100644 index 0000000000000000000000000000000000000000..fd8a4047b1b8dac25e3ef7e074682ad6b67cfeab --- /dev/null +++ b/drivers/thunder_browser/util.go @@ -0,0 +1,249 @@ +package thunder_browser + +import ( + "crypto/md5" + "crypto/sha1" + "encoding/hex" + "fmt" + "io" + "net/http" + "regexp" + "time" + + 
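`GetPath` just above is repurposed: instead of a filesystem path it returns a sentinel constant derived from `FileType`, and every CRUD method in this driver branches on it three ways (regular `SPACE_BROWSER`, the 迅雷云盘 / Thunder cloud drive area that omits the `space` field, and the `SPACE_BROWSER_SAFE` vault). A hypothetical helper, not part of this patch, could centralize that repeated mapping:

```go
// Hypothetical consolidation of the three-way space selection (not in the patch):
func resolveSpace(obj model.Obj) (space string, send bool) {
	switch obj.GetPath() {
	case ThunderDriveFileID: // Thunder cloud drive: requests omit "space"
		return "", false
	case ThunderBrowserDriveSafeFileID: // encrypted vault
		return "SPACE_BROWSER_SAFE", true
	default:
		return "SPACE_BROWSER", true
	}
}
```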
"github.com/alist-org/alist/v3/drivers/base" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/go-resty/resty/v2" +) + +const ( + API_URL = "https://x-api-pan.xunlei.com/drive/v1" + FILE_API_URL = API_URL + "/files" + XLUSER_API_URL = "https://xluser-ssl.xunlei.com/v1" +) + +const ( + FOLDER = "drive#folder" + FILE = "drive#file" + RESUMABLE = "drive#resumable" +) + +const ( + UPLOAD_TYPE_UNKNOWN = "UPLOAD_TYPE_UNKNOWN" + //UPLOAD_TYPE_FORM = "UPLOAD_TYPE_FORM" + UPLOAD_TYPE_RESUMABLE = "UPLOAD_TYPE_RESUMABLE" + UPLOAD_TYPE_URL = "UPLOAD_TYPE_URL" +) + +const ( + ThunderDriveFileID = "XXXXXXXXXXXXXXXXXXXXXXXXXX" + ThunderBrowserDriveSafeFileID = "YYYYYYYYYYYYYYYYYYYYYYYYYY" + ThunderDriveFolderName = "迅雷云盘" + ThunderBrowserDriveSafeFolderName = "超级保险箱" + ThunderDriveType = 1 + ThunderBrowserDriveSafeType = 2 + ThunderDriveFolderType = "DEFAULT_ROOT" + ThunderBrowserDriveSafeFolderType = "BROWSER_SAFE" +) + +func GetAction(method string, url string) string { + urlpath := regexp.MustCompile(`://[^/]+((/[^/\s?#]+)*)`).FindStringSubmatch(url)[1] + return method + ":" + urlpath +} + +type Common struct { + client *resty.Client + + captchaToken string + + // 签名相关,二选一 + Algorithms []string + Timestamp, CaptchaSign string + + // 必要值,签名相关 + DeviceID string + ClientID string + ClientSecret string + ClientVersion string + PackageName string + UserAgent string + DownloadUserAgent string + UseVideoUrl bool + RemoveWay string + + // 验证码token刷新成功回调 + refreshCTokenCk func(token string) +} + +func (c *Common) SetCaptchaToken(captchaToken string) { + c.captchaToken = captchaToken +} +func (c *Common) GetCaptchaToken() string { + return c.captchaToken +} + +// RefreshCaptchaTokenAtLogin 刷新验证码token(登录后) +func (c *Common) RefreshCaptchaTokenAtLogin(action, userID string) error { + metas := map[string]string{ + "client_version": c.ClientVersion, + "package_name": c.PackageName, + "user_id": userID, + } + metas["timestamp"], metas["captcha_sign"] = c.GetCaptchaSign() + return c.refreshCaptchaToken(action, metas) +} + +// RefreshCaptchaTokenInLogin 刷新验证码token(登录时) +func (c *Common) RefreshCaptchaTokenInLogin(action, username string) error { + metas := make(map[string]string) + if ok, _ := regexp.MatchString(`\w+([-+.]\w+)*@\w+([-.]\w+)*\.\w+([-.]\w+)*`, username); ok { + metas["email"] = username + } else if len(username) >= 11 && len(username) <= 18 { + metas["phone_number"] = username + } else { + metas["username"] = username + } + return c.refreshCaptchaToken(action, metas) +} + +// GetCaptchaSign 获取验证码签名 +func (c *Common) GetCaptchaSign() (timestamp, sign string) { + if len(c.Algorithms) == 0 { + return c.Timestamp, c.CaptchaSign + } + timestamp = fmt.Sprint(time.Now().UnixMilli()) + str := fmt.Sprint(c.ClientID, c.ClientVersion, c.PackageName, c.DeviceID, timestamp) + for _, algorithm := range c.Algorithms { + str = utils.GetMD5EncodeStr(str + algorithm) + } + sign = "1." 
+ str + return +} + +// 刷新验证码token +func (c *Common) refreshCaptchaToken(action string, metas map[string]string) error { + param := CaptchaTokenRequest{ + Action: action, + CaptchaToken: c.captchaToken, + ClientID: c.ClientID, + DeviceID: c.DeviceID, + Meta: metas, + RedirectUri: "xlaccsdk01://xunlei.com/callback?state=harbor", + } + var e ErrResp + var resp CaptchaTokenResponse + _, err := c.Request(XLUSER_API_URL+"/shield/captcha/init", http.MethodPost, func(req *resty.Request) { + req.SetError(&e).SetBody(param) + }, &resp) + + if err != nil { + return err + } + + if e.IsError() { + return &e + } + + if resp.Url != "" { + return fmt.Errorf(`need verify: Click Here`, resp.Url) + } + + if resp.CaptchaToken == "" { + return fmt.Errorf("empty captchaToken") + } + + if c.refreshCTokenCk != nil { + c.refreshCTokenCk(resp.CaptchaToken) + } + c.SetCaptchaToken(resp.CaptchaToken) + return nil +} + +// Request 只有基础信息的请求 +func (c *Common) Request(url, method string, callback base.ReqCallback, resp interface{}) ([]byte, error) { + req := c.client.R().SetHeaders(map[string]string{ + "user-agent": c.UserAgent, + "accept": "application/json;charset=UTF-8", + "x-device-id": c.DeviceID, + "x-client-id": c.ClientID, + "x-client-version": c.ClientVersion, + }) + + if callback != nil { + callback(req) + } + if resp != nil { + req.SetResult(resp) + } + res, err := req.Execute(method, url) + if err != nil { + return nil, err + } + + var erron ErrResp + utils.Json.Unmarshal(res.Body(), &erron) + if erron.IsError() { + return nil, &erron + } + + return res.Body(), nil +} + +// 计算文件Gcid +func getGcid(r io.Reader, size int64) (string, error) { + calcBlockSize := func(j int64) int64 { + var psize int64 = 0x40000 + for float64(j)/float64(psize) > 0x200 && psize < 0x200000 { + psize = psize << 1 + } + return psize + } + + hash1 := sha1.New() + hash2 := sha1.New() + readSize := calcBlockSize(size) + for { + hash2.Reset() + if n, err := utils.CopyWithBufferN(hash2, r, readSize); err != nil && n == 0 { + if err != io.EOF { + return "", err + } + break + } + hash1.Write(hash2.Sum(nil)) + } + return hex.EncodeToString(hash1.Sum(nil)), nil +} + +type CustomTime struct { + time.Time +} + +const timeFormat = time.RFC3339 + +func (ct *CustomTime) UnmarshalJSON(b []byte) error { + str := string(b) + if str == `""` { + *ct = CustomTime{Time: time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC)} + return nil + } + + t, err := time.Parse(`"`+timeFormat+`"`, str) + if err != nil { + return err + } + *ct = CustomTime{Time: t} + return nil +} + +// EncryptPassword 超级保险箱 加密 +func EncryptPassword(password string) string { + if password == "" { + return "" + } + // 将字符串转换为字节数组 + byteData := []byte(password) + // 计算MD5哈希值 + hash := md5.Sum(byteData) + // 将哈希值转换为十六进制字符串 + return hex.EncodeToString(hash[:]) +} diff --git a/drivers/thunderx/driver.go b/drivers/thunderx/driver.go new file mode 100644 index 0000000000000000000000000000000000000000..06e0cca5978891799157331ece89e29fce277812 --- /dev/null +++ b/drivers/thunderx/driver.go @@ -0,0 +1,605 @@ +package thunderx + +import ( + "context" + "fmt" + "net/http" + "strings" + + "github.com/alist-org/alist/v3/drivers/base" + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/errs" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/internal/op" + "github.com/alist-org/alist/v3/pkg/utils" + hash_extend "github.com/alist-org/alist/v3/pkg/utils/hash" + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/credentials" + 
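Two review notes on the helpers above. First, `refreshCaptchaToken` calls `fmt.Errorf` with `resp.Url` as an argument but no formatting verb in the string, so `go vet` flags it and the verification URL never reaches the message (the same line recurs in `drivers/thunderx/util.go` below; possibly an HTML link containing `%s` was lost along the way):

```go
// Likely intended form (assumption: the URL was meant to be shown to the user):
return fmt.Errorf("need verify: %s", resp.Url)
```

Second, `getGcid`'s block-size schedule is worth spelling out:

```go
// calcBlockSize doubles the block from 256 KiB until the file fits in 512
// blocks or the 2 MiB cap is reached:
//   size ≤ 128 MiB -> 256 KiB blocks
//   size ≤ 256 MiB -> 512 KiB blocks
//   size ≤ 512 MiB ->   1 MiB blocks
//   larger         ->   2 MiB blocks (block count may then exceed 512)
// GCID = sha1( sha1(block_1) || sha1(block_2) || ... ), hex-encoded.
```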
"github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/s3/s3manager" + "github.com/go-resty/resty/v2" +) + +type ThunderX struct { + *XunLeiXCommon + model.Storage + Addition + + identity string +} + +func (x *ThunderX) Config() driver.Config { + return config +} + +func (x *ThunderX) GetAddition() driver.Additional { + return &x.Addition +} + +func (x *ThunderX) Init(ctx context.Context) (err error) { + // 初始化所需参数 + if x.XunLeiXCommon == nil { + x.XunLeiXCommon = &XunLeiXCommon{ + Common: &Common{ + client: base.NewRestyClient(), + Algorithms: Algorithms, + DeviceID: utils.GetMD5EncodeStr(x.Username + x.Password), + ClientID: ClientID, + ClientSecret: ClientSecret, + ClientVersion: ClientVersion, + PackageName: PackageName, + UserAgent: BuildCustomUserAgent(utils.GetMD5EncodeStr(x.Username+x.Password), ClientID, PackageName, SdkVersion, ClientVersion, PackageName, ""), + DownloadUserAgent: DownloadUserAgent, + UseVideoUrl: x.UseVideoUrl, + UseProxy: x.UseProxy, + //下载地址是否使用代理 + UseUrlProxy: x.UseUrlProxy, + ProxyUrl: x.ProxyUrl, + + refreshCTokenCk: func(token string) { + x.CaptchaToken = token + op.MustSaveDriverStorage(x) + }, + }, + refreshTokenFunc: func() error { + // 通过RefreshToken刷新 + token, err := x.RefreshToken(x.TokenResp.RefreshToken) + if err != nil { + // 重新登录 + token, err = x.Login(x.Username, x.Password) + if err != nil { + x.GetStorage().SetStatus(fmt.Sprintf("%+v", err.Error())) + if token.UserID != "" { + x.SetUserID(token.UserID) + x.UserAgent = BuildCustomUserAgent(utils.GetMD5EncodeStr(x.Username+x.Password), ClientID, PackageName, SdkVersion, ClientVersion, PackageName, token.UserID) + } + op.MustSaveDriverStorage(x) + } + } + x.SetTokenResp(token) + return err + }, + } + } + + // 自定义验证码token + ctoken := strings.TrimSpace(x.CaptchaToken) + if ctoken != "" { + x.SetCaptchaToken(ctoken) + } + if x.DeviceID == "" { + x.SetDeviceID(utils.GetMD5EncodeStr(x.Username + x.Password)) + } + + x.XunLeiXCommon.UseVideoUrl = x.UseVideoUrl + x.Addition.RootFolderID = x.RootFolderID + // 防止重复登录 + identity := x.GetIdentity() + if x.identity != identity || !x.IsLogin() { + x.identity = identity + // 登录 + token, err := x.Login(x.Username, x.Password) + if err != nil { + return err + } + x.SetTokenResp(token) + if token.UserID != "" { + x.SetUserID(token.UserID) + x.UserAgent = BuildCustomUserAgent(x.DeviceID, ClientID, PackageName, SdkVersion, ClientVersion, PackageName, token.UserID) + } + } + return nil +} + +func (x *ThunderX) Drop(ctx context.Context) error { + return nil +} + +type ThunderXExpert struct { + *XunLeiXCommon + model.Storage + ExpertAddition + + identity string +} + +func (x *ThunderXExpert) Config() driver.Config { + return configExpert +} + +func (x *ThunderXExpert) GetAddition() driver.Additional { + return &x.ExpertAddition +} + +func (x *ThunderXExpert) Init(ctx context.Context) (err error) { + // 防止重复登录 + identity := x.GetIdentity() + if identity != x.identity || !x.IsLogin() { + x.identity = identity + x.XunLeiXCommon = &XunLeiXCommon{ + Common: &Common{ + client: base.NewRestyClient(), + + DeviceID: func() string { + if len(x.DeviceID) != 32 { + if x.LoginType == "user" { + return utils.GetMD5EncodeStr(x.Username + x.Password) + } + return utils.GetMD5EncodeStr(x.ExpertAddition.RefreshToken) + } + return x.DeviceID + }(), + ClientID: x.ClientID, + ClientSecret: x.ClientSecret, + ClientVersion: x.ClientVersion, + PackageName: x.PackageName, + UserAgent: func() string { + if x.ExpertAddition.UserAgent != "" { + return 
x.ExpertAddition.UserAgent + } + if x.LoginType == "user" { + return BuildCustomUserAgent(utils.GetMD5EncodeStr(x.Username+x.Password), ClientID, PackageName, SdkVersion, ClientVersion, PackageName, "") + } + return BuildCustomUserAgent(utils.GetMD5EncodeStr(x.ExpertAddition.RefreshToken), ClientID, PackageName, SdkVersion, ClientVersion, PackageName, "") + }(), + DownloadUserAgent: func() string { + if x.ExpertAddition.DownloadUserAgent != "" { + return x.ExpertAddition.DownloadUserAgent + } + return DownloadUserAgent + }(), + UseVideoUrl: x.UseVideoUrl, + UseProxy: x.ExpertAddition.UseProxy, + //下载地址是否使用代理 + UseUrlProxy: x.ExpertAddition.UseUrlProxy, + ProxyUrl: x.ExpertAddition.ProxyUrl, + refreshCTokenCk: func(token string) { + x.CaptchaToken = token + op.MustSaveDriverStorage(x) + }, + }, + } + + if x.ExpertAddition.CaptchaToken != "" { + x.SetCaptchaToken(x.ExpertAddition.CaptchaToken) + op.MustSaveDriverStorage(x) + } + if x.Common.DeviceID != "" { + x.ExpertAddition.DeviceID = x.Common.DeviceID + op.MustSaveDriverStorage(x) + } + if x.Common.DownloadUserAgent != "" { + x.ExpertAddition.DownloadUserAgent = x.Common.DownloadUserAgent + op.MustSaveDriverStorage(x) + } + x.XunLeiXCommon.UseVideoUrl = x.UseVideoUrl + x.ExpertAddition.RootFolderID = x.RootFolderID + // 签名方法 + if x.SignType == "captcha_sign" { + x.Common.Timestamp = x.Timestamp + x.Common.CaptchaSign = x.CaptchaSign + } else { + x.Common.Algorithms = strings.Split(x.Algorithms, ",") + } + + // 登录方式 + if x.LoginType == "refresh_token" { + // 通过RefreshToken登录 + token, err := x.XunLeiXCommon.RefreshToken(x.ExpertAddition.RefreshToken) + if err != nil { + return err + } + x.SetTokenResp(token) + // 刷新token方法 + x.SetRefreshTokenFunc(func() error { + token, err := x.XunLeiXCommon.RefreshToken(x.TokenResp.RefreshToken) + if err != nil { + x.GetStorage().SetStatus(fmt.Sprintf("%+v", err.Error())) + } + x.SetTokenResp(token) + op.MustSaveDriverStorage(x) + return err + }) + } else { + // 通过用户密码登录 + token, err := x.Login(x.Username, x.Password) + if err != nil { + return err + } + x.SetTokenResp(token) + x.SetRefreshTokenFunc(func() error { + token, err := x.XunLeiXCommon.RefreshToken(x.TokenResp.RefreshToken) + if err != nil { + token, err = x.Login(x.Username, x.Password) + if err != nil { + x.GetStorage().SetStatus(fmt.Sprintf("%+v", err.Error())) + } + } + x.SetTokenResp(token) + op.MustSaveDriverStorage(x) + return err + }) + } + // 更新 UserAgent + if x.TokenResp.UserID != "" { + x.ExpertAddition.UserAgent = BuildCustomUserAgent(x.ExpertAddition.DeviceID, ClientID, PackageName, SdkVersion, ClientVersion, PackageName, x.TokenResp.UserID) + x.SetUserAgent(x.ExpertAddition.UserAgent) + op.MustSaveDriverStorage(x) + } + } else { + // 仅修改验证码token + if x.CaptchaToken != "" { + x.SetCaptchaToken(x.CaptchaToken) + } + x.XunLeiXCommon.UserAgent = x.ExpertAddition.UserAgent + x.XunLeiXCommon.DownloadUserAgent = x.ExpertAddition.UserAgent + x.XunLeiXCommon.UseVideoUrl = x.UseVideoUrl + x.ExpertAddition.RootFolderID = x.RootFolderID + } + return nil +} + +func (x *ThunderXExpert) Drop(ctx context.Context) error { + return nil +} + +func (x *ThunderXExpert) SetTokenResp(token *TokenResp) { + x.XunLeiXCommon.SetTokenResp(token) + if token != nil { + x.ExpertAddition.RefreshToken = token.RefreshToken + } +} + +type XunLeiXCommon struct { + *Common + *TokenResp // 登录信息 + + refreshTokenFunc func() error +} + +func (xc *XunLeiXCommon) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) { + return xc.getFiles(ctx, 
dir.GetID()) +} + +func (xc *XunLeiXCommon) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) { + var lFile Files + _, err := xc.Request(FILE_API_URL+"/{fileID}", http.MethodGet, func(r *resty.Request) { + r.SetContext(ctx) + r.SetPathParam("fileID", file.GetID()) + //r.SetQueryParam("space", "") + }, &lFile) + if err != nil { + return nil, err + } + + link := &model.Link{ + URL: lFile.WebContentLink, + Header: http.Header{ + "User-Agent": {xc.DownloadUserAgent}, + }, + } + + if xc.UseVideoUrl { + for _, media := range lFile.Medias { + if media.Link.URL != "" { + link.URL = media.Link.URL + break + } + } + } + + if xc.UseUrlProxy { + if strings.HasSuffix(xc.ProxyUrl, "/") { + link.URL = xc.ProxyUrl + link.URL + } else { + link.URL = xc.ProxyUrl + "/" + link.URL + } + } + + /* + strs := regexp.MustCompile(`e=([0-9]*)`).FindStringSubmatch(lFile.WebContentLink) + if len(strs) == 2 { + timestamp, err := strconv.ParseInt(strs[1], 10, 64) + if err == nil { + expired := time.Duration(timestamp-time.Now().Unix()) * time.Second + link.Expiration = &expired + } + } + */ + return link, nil +} + +func (xc *XunLeiXCommon) MakeDir(ctx context.Context, parentDir model.Obj, dirName string) error { + _, err := xc.Request(FILE_API_URL, http.MethodPost, func(r *resty.Request) { + r.SetContext(ctx) + r.SetBody(&base.Json{ + "kind": FOLDER, + "name": dirName, + "parent_id": parentDir.GetID(), + }) + }, nil) + return err +} + +func (xc *XunLeiXCommon) Move(ctx context.Context, srcObj, dstDir model.Obj) error { + _, err := xc.Request(FILE_API_URL+":batchMove", http.MethodPost, func(r *resty.Request) { + r.SetContext(ctx) + r.SetBody(&base.Json{ + "to": base.Json{"parent_id": dstDir.GetID()}, + "ids": []string{srcObj.GetID()}, + }) + }, nil) + return err +} + +func (xc *XunLeiXCommon) Rename(ctx context.Context, srcObj model.Obj, newName string) error { + _, err := xc.Request(FILE_API_URL+"/{fileID}", http.MethodPatch, func(r *resty.Request) { + r.SetContext(ctx) + r.SetPathParam("fileID", srcObj.GetID()) + r.SetBody(&base.Json{"name": newName}) + }, nil) + return err +} + +func (xc *XunLeiXCommon) Offline(ctx context.Context, args model.OtherArgs) (interface{}, error) { + _, err := xc.Request(FILE_API_URL, http.MethodPost, func(r *resty.Request) { + r.SetContext(ctx) + r.SetHeaders(map[string]string{ + "X-Device-Id": xc.DeviceID, + "User-Agent": xc.UserAgent, + "Peer-Id": xc.DeviceID, + "client_id": xc.ClientID, + "x-client-id": xc.ClientID, + "X-Guid": xc.DeviceID, + }) + r.SetBody(&base.Json{ + "kind": "drive#file", + "name": "", + "parent_id": args.Obj.GetID(), + "upload_type": "UPLOAD_TYPE_URL", + "url": &base.Json{ + "url": args.Data, + "params": "{}", + "parent_id": args.Obj.GetID(), + }, + }) + }, nil) + if err != nil { + return nil, err + } + return "ok", nil +} + +func (xc *XunLeiXCommon) Copy(ctx context.Context, srcObj, dstDir model.Obj) error { + _, err := xc.Request(FILE_API_URL+":batchCopy", http.MethodPost, func(r *resty.Request) { + r.SetContext(ctx) + r.SetBody(&base.Json{ + "to": base.Json{"parent_id": dstDir.GetID()}, + "ids": []string{srcObj.GetID()}, + }) + }, nil) + return err +} + +func (xc *XunLeiXCommon) Remove(ctx context.Context, obj model.Obj) error { + _, err := xc.Request(FILE_API_URL+"/{fileID}/trash", http.MethodPatch, func(r *resty.Request) { + r.SetContext(ctx) + r.SetPathParam("fileID", obj.GetID()) + r.SetBody("{}") + }, nil) + return err +} + +func (xc *XunLeiXCommon) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up 
driver.UpdateProgress) error { + hi := stream.GetHash() + gcid := hi.GetHash(hash_extend.GCID) + if len(gcid) < hash_extend.GCID.Width { + tFile, err := stream.CacheFullInTempFile() + if err != nil { + return err + } + + gcid, err = utils.HashFile(hash_extend.GCID, tFile, stream.GetSize()) + if err != nil { + return err + } + } + + var resp UploadTaskResponse + _, err := xc.Request(FILE_API_URL, http.MethodPost, func(r *resty.Request) { + r.SetContext(ctx) + r.SetBody(&base.Json{ + "kind": FILE, + "parent_id": dstDir.GetID(), + "name": stream.GetName(), + "size": stream.GetSize(), + "hash": gcid, + "upload_type": UPLOAD_TYPE_RESUMABLE, + }) + }, &resp) + if err != nil { + return err + } + + param := resp.Resumable.Params + if resp.UploadType == UPLOAD_TYPE_RESUMABLE { + param.Endpoint = strings.TrimLeft(param.Endpoint, param.Bucket+".") + s, err := session.NewSession(&aws.Config{ + Credentials: credentials.NewStaticCredentials(param.AccessKeyID, param.AccessKeySecret, param.SecurityToken), + Region: aws.String("xunlei"), + Endpoint: aws.String(param.Endpoint), + }) + if err != nil { + return err + } + uploader := s3manager.NewUploader(s) + if stream.GetSize() > s3manager.MaxUploadParts*s3manager.DefaultUploadPartSize { + uploader.PartSize = stream.GetSize() / (s3manager.MaxUploadParts - 1) + } + _, err = uploader.UploadWithContext(ctx, &s3manager.UploadInput{ + Bucket: aws.String(param.Bucket), + Key: aws.String(param.Key), + Expires: aws.Time(param.Expiration), + Body: stream, + }) + return err + } + return nil +} + +func (xc *XunLeiXCommon) getFiles(ctx context.Context, folderId string) ([]model.Obj, error) { + files := make([]model.Obj, 0) + var pageToken string + for { + var fileList FileList + _, err := xc.Request(FILE_API_URL, http.MethodGet, func(r *resty.Request) { + r.SetContext(ctx) + r.SetQueryParams(map[string]string{ + "space": "", + "__type": "drive", + "refresh": "true", + "__sync": "true", + "parent_id": folderId, + "page_token": pageToken, + "with_audit": "true", + "limit": "100", + "filters": `{"phase":{"eq":"PHASE_TYPE_COMPLETE"},"trashed":{"eq":false}}`, + }) + }, &fileList) + if err != nil { + return nil, err + } + + for i := 0; i < len(fileList.Files); i++ { + files = append(files, &fileList.Files[i]) + } + + if fileList.NextPageToken == "" { + break + } + pageToken = fileList.NextPageToken + } + return files, nil +} + +// SetRefreshTokenFunc 设置刷新Token的方法 +func (xc *XunLeiXCommon) SetRefreshTokenFunc(fn func() error) { + xc.refreshTokenFunc = fn +} + +// SetTokenResp 设置Token +func (xc *XunLeiXCommon) SetTokenResp(tr *TokenResp) { + xc.TokenResp = tr +} + +// Request 携带Authorization和CaptchaToken的请求 +func (xc *XunLeiXCommon) Request(url string, method string, callback base.ReqCallback, resp interface{}) ([]byte, error) { + data, err := xc.Common.Request(url, method, func(req *resty.Request) { + req.SetHeaders(map[string]string{ + "Authorization": xc.Token(), + "X-Captcha-Token": xc.GetCaptchaToken(), + }) + if callback != nil { + callback(req) + } + }, resp) + + errResp, ok := err.(*ErrResp) + if !ok { + return nil, err + } + + switch errResp.ErrorCode { + case 0: + return data, nil + case 4122, 4121, 10, 16: + if xc.refreshTokenFunc != nil { + if err = xc.refreshTokenFunc(); err == nil { + break + } + } + return nil, err + case 9: // 验证码token过期 + if err = xc.RefreshCaptchaTokenAtLogin(GetAction(method, url), xc.UserID); err != nil { + return nil, err + } + default: + return nil, err + } + return xc.Request(url, method, callback, resp) +} + +// RefreshToken 刷新Token 
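The token lifecycle wired up in `Init` earlier in this file is worth summarizing next to `RefreshToken`: the refresh-token grant is tried first and a full credential login is the fallback. A sketch of the closure defined above, with error handling condensed:

```go
// Refresh-then-relogin fallback (condensed from Init's refreshTokenFunc):
token, err := x.RefreshToken(x.TokenResp.RefreshToken)
if err != nil {
	token, err = x.Login(x.Username, x.Password) // second chance with credentials
}
x.SetTokenResp(token)
```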
+func (xc *XunLeiXCommon) RefreshToken(refreshToken string) (*TokenResp, error) { + var resp TokenResp + _, err := xc.Common.Request(XLUSER_API_URL+"/auth/token", http.MethodPost, func(req *resty.Request) { + req.SetBody(&base.Json{ + "grant_type": "refresh_token", + "refresh_token": refreshToken, + "client_id": xc.ClientID, + "client_secret": xc.ClientSecret, + }) + }, &resp) + if err != nil { + return nil, err + } + + if resp.RefreshToken == "" { + return nil, errs.EmptyToken + } + resp.UserID = resp.Sub + return &resp, nil +} + +// Login 登录 +func (xc *XunLeiXCommon) Login(username, password string) (*TokenResp, error) { + url := XLUSER_API_URL + "/auth/signin" + err := xc.RefreshCaptchaTokenInLogin(GetAction(http.MethodPost, url), username) + if err != nil { + return nil, err + } + + var resp TokenResp + _, err = xc.Common.Request(url, http.MethodPost, func(req *resty.Request) { + req.SetBody(&SignInRequest{ + CaptchaToken: xc.GetCaptchaToken(), + ClientID: xc.ClientID, + ClientSecret: xc.ClientSecret, + Username: username, + Password: password, + }) + }, &resp) + if err != nil { + return nil, err + } + resp.UserID = resp.Sub + return &resp, nil +} + +func (xc *XunLeiXCommon) IsLogin() bool { + if xc.TokenResp == nil { + return false + } + _, err := xc.Request(XLUSER_API_URL+"/user/me", http.MethodGet, nil, nil) + return err == nil +} diff --git a/drivers/thunderx/meta.go b/drivers/thunderx/meta.go new file mode 100644 index 0000000000000000000000000000000000000000..3dbd00e8fed51e6db1cc4b91a8e8580d4da66fc2 --- /dev/null +++ b/drivers/thunderx/meta.go @@ -0,0 +1,113 @@ +package thunderx + +import ( + "crypto/md5" + "encoding/hex" + + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/op" + "github.com/alist-org/alist/v3/pkg/utils" +) + +// 高级设置 +type ExpertAddition struct { + driver.RootID + + LoginType string `json:"login_type" type:"select" options:"user,refresh_token" default:"user"` + SignType string `json:"sign_type" type:"select" options:"algorithms,captcha_sign" default:"algorithms"` + + // 登录方式1 + Username string `json:"username" required:"true" help:"login type is user,this is required"` + Password string `json:"password" required:"true" help:"login type is user,this is required"` + // 登录方式2 + RefreshToken string `json:"refresh_token" required:"true" help:"login type is refresh_token,this is required"` + + // 签名方法1 + Algorithms string `json:"algorithms" required:"true" help:"sign type is algorithms,this is required" default:"kVy0WbPhiE4v6oxXZ88DvoA3Q,lON/AUoZKj8/nBtcE85mVbkOaVdVa,rLGffQrfBKH0BgwQ33yZofvO3Or,FO6HWqw,GbgvyA2,L1NU9QvIQIH7DTRt,y7llk4Y8WfYflt6,iuDp1WPbV3HRZudZtoXChxH4HNVBX5ZALe,8C28RTXmVcco0,X5Xh,7xe25YUgfGgD0xW3ezFS,,CKCR,8EmDjBo6h3eLaK7U6vU2Qys0NsMx,t2TeZBXKqbdP09Arh9C3"` + // 签名方法2 + CaptchaSign string `json:"captcha_sign" required:"true" help:"sign type is captcha_sign,this is required"` + Timestamp string `json:"timestamp" required:"true" help:"sign type is captcha_sign,this is required"` + + // 验证码 + CaptchaToken string `json:"captcha_token"` + + // 必要且影响登录,由签名决定 + DeviceID string `json:"device_id" required:"false" default:""` + ClientID string `json:"client_id" required:"true" default:"ZQL_zwA4qhHcoe_2"` + ClientSecret string `json:"client_secret" required:"true" default:"Og9Vr1L8Ee6bh0olFxFDRg"` + ClientVersion string `json:"client_version" required:"true" default:"1.06.0.2132"` + PackageName string `json:"package_name" required:"true" default:"com.thunder.downloader"` + + ////不影响登录,影响下载速度 + UserAgent string 
`json:"user_agent" required:"false" default:""` + DownloadUserAgent string `json:"download_user_agent" required:"false" default:""` + + //优先使用视频链接代替下载链接 + UseVideoUrl bool `json:"use_video_url" default:"true"` + //是否使用代理 + UseProxy bool `json:"use_proxy"` + //下载地址是否使用代理 + UseUrlProxy bool `json:"use_url_proxy"` + ProxyUrl string `json:"proxy_url" default:""` +} + +// 登录特征,用于判断是否重新登录 +func (i *ExpertAddition) GetIdentity() string { + hash := md5.New() + if i.LoginType == "refresh_token" { + hash.Write([]byte(i.RefreshToken)) + } else { + hash.Write([]byte(i.Username + i.Password)) + } + + if i.SignType == "captcha_sign" { + hash.Write([]byte(i.CaptchaSign + i.Timestamp)) + } else { + hash.Write([]byte(i.Algorithms)) + } + + hash.Write([]byte(i.DeviceID)) + hash.Write([]byte(i.ClientID)) + hash.Write([]byte(i.ClientSecret)) + hash.Write([]byte(i.ClientVersion)) + hash.Write([]byte(i.PackageName)) + return hex.EncodeToString(hash.Sum(nil)) +} + +type Addition struct { + driver.RootID + Username string `json:"username" required:"true"` + Password string `json:"password" required:"true"` + CaptchaToken string `json:"captcha_token"` + UseVideoUrl bool `json:"use_video_url" default:"true"` + //是否使用代理 + UseProxy bool `json:"use_proxy"` + //下载地址是否使用代理 + UseUrlProxy bool `json:"use_url_proxy"` + ProxyUrl string `json:"proxy_url" default:""` +} + +// 登录特征,用于判断是否重新登录 +func (i *Addition) GetIdentity() string { + return utils.GetMD5EncodeStr(i.Username + i.Password) +} + +var config = driver.Config{ + Name: "ThunderX", + LocalSort: true, + OnlyProxy: false, +} + +var configExpert = driver.Config{ + Name: "ThunderXExpert", + LocalSort: true, +} + +func init() { + op.RegisterDriver(func() driver.Driver { + return &ThunderX{} + }) + op.RegisterDriver(func() driver.Driver { + return &ThunderXExpert{} + }) +} diff --git a/drivers/thunderx/types.go b/drivers/thunderx/types.go new file mode 100644 index 0000000000000000000000000000000000000000..77cfa0f2415c50ed640a4aac3f4fa44c425034b2 --- /dev/null +++ b/drivers/thunderx/types.go @@ -0,0 +1,206 @@ +package thunderx + +import ( + "fmt" + "strconv" + "time" + + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/pkg/utils" + hash_extend "github.com/alist-org/alist/v3/pkg/utils/hash" +) + +type ErrResp struct { + ErrorCode int64 `json:"error_code"` + ErrorMsg string `json:"error"` + ErrorDescription string `json:"error_description"` + // ErrorDetails interface{} `json:"error_details"` +} + +func (e *ErrResp) IsError() bool { + return e.ErrorCode != 0 || e.ErrorMsg != "" || e.ErrorDescription != "" +} + +func (e *ErrResp) Error() string { + return fmt.Sprintf("ErrorCode: %d ,Error: %s ,ErrorDescription: %s ", e.ErrorCode, e.ErrorMsg, e.ErrorDescription) +} + +/* +* 验证码Token +**/ +type CaptchaTokenRequest struct { + Action string `json:"action"` + CaptchaToken string `json:"captcha_token"` + ClientID string `json:"client_id"` + DeviceID string `json:"device_id"` + Meta map[string]string `json:"meta"` + RedirectUri string `json:"redirect_uri"` +} + +type CaptchaTokenResponse struct { + CaptchaToken string `json:"captcha_token"` + ExpiresIn int64 `json:"expires_in"` + Url string `json:"url"` +} + +/* +* 登录 +**/ +type TokenResp struct { + TokenType string `json:"token_type"` + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + ExpiresIn int64 `json:"expires_in"` + + Sub string `json:"sub"` + UserID string `json:"user_id"` +} + +func (t *TokenResp) Token() string { + return fmt.Sprint(t.TokenType, " ", 
t.AccessToken) +} + +type SignInRequest struct { + CaptchaToken string `json:"captcha_token"` + + ClientID string `json:"client_id"` + ClientSecret string `json:"client_secret"` + + Username string `json:"username"` + Password string `json:"password"` +} + +/* +* 文件 +**/ +type FileList struct { + Kind string `json:"kind"` + NextPageToken string `json:"next_page_token"` + Files []Files `json:"files"` + Version string `json:"version"` + VersionOutdated bool `json:"version_outdated"` +} + +type Link struct { + URL string `json:"url"` + Token string `json:"token"` + Expire time.Time `json:"expire"` + Type string `json:"type"` +} + +var _ model.Obj = (*Files)(nil) + +type Files struct { + Kind string `json:"kind"` + ID string `json:"id"` + ParentID string `json:"parent_id"` + Name string `json:"name"` + //UserID string `json:"user_id"` + Size string `json:"size"` + //Revision string `json:"revision"` + //FileExtension string `json:"file_extension"` + //MimeType string `json:"mime_type"` + //Starred bool `json:"starred"` + WebContentLink string `json:"web_content_link"` + CreatedTime time.Time `json:"created_time"` + ModifiedTime time.Time `json:"modified_time"` + IconLink string `json:"icon_link"` + ThumbnailLink string `json:"thumbnail_link"` + // Md5Checksum string `json:"md5_checksum"` + Hash string `json:"hash"` + // Links map[string]Link `json:"links"` + // Phase string `json:"phase"` + // Audit struct { + // Status string `json:"status"` + // Message string `json:"message"` + // Title string `json:"title"` + // } `json:"audit"` + Medias []struct { + //Category string `json:"category"` + //IconLink string `json:"icon_link"` + //IsDefault bool `json:"is_default"` + //IsOrigin bool `json:"is_origin"` + //IsVisible bool `json:"is_visible"` + Link Link `json:"link"` + //MediaID string `json:"media_id"` + //MediaName string `json:"media_name"` + //NeedMoreQuota bool `json:"need_more_quota"` + //Priority int `json:"priority"` + //RedirectLink string `json:"redirect_link"` + //ResolutionName string `json:"resolution_name"` + // Video struct { + // AudioCodec string `json:"audio_codec"` + // BitRate int `json:"bit_rate"` + // Duration int `json:"duration"` + // FrameRate int `json:"frame_rate"` + // Height int `json:"height"` + // VideoCodec string `json:"video_codec"` + // VideoType string `json:"video_type"` + // Width int `json:"width"` + // } `json:"video"` + // VipTypes []string `json:"vip_types"` + } `json:"medias"` + Trashed bool `json:"trashed"` + DeleteTime string `json:"delete_time"` + OriginalURL string `json:"original_url"` + //Params struct{} `json:"params"` + //OriginalFileIndex int `json:"original_file_index"` + //Space string `json:"space"` + //Apps []interface{} `json:"apps"` + //Writable bool `json:"writable"` + //FolderType string `json:"folder_type"` + //Collection interface{} `json:"collection"` +} + +func (c *Files) GetHash() utils.HashInfo { + return utils.NewHashInfo(hash_extend.GCID, c.Hash) +} + +func (c *Files) GetSize() int64 { size, _ := strconv.ParseInt(c.Size, 10, 64); return size } +func (c *Files) GetName() string { return c.Name } +func (c *Files) CreateTime() time.Time { return c.CreatedTime } +func (c *Files) ModTime() time.Time { return c.ModifiedTime } +func (c *Files) IsDir() bool { return c.Kind == FOLDER } +func (c *Files) GetID() string { return c.ID } +func (c *Files) GetPath() string { return "" } +func (c *Files) Thumb() string { return c.ThumbnailLink } + +/* +* 上传 +**/ +type UploadTaskResponse struct { + UploadType string `json:"upload_type"` + + 
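The resumable branch of `Put` sizes S3 parts against two aws-sdk-go constants; the arithmetic behind the check deserves a note:

```go
// s3manager.MaxUploadParts == 10000, s3manager.DefaultUploadPartSize == 5 MiB,
// so default-sized parts top out at 10000 * 5 MiB ≈ 48.8 GiB. Above that, Put
// raises the part size. The floored division by (MaxUploadParts - 1) leaves
// headroom so the remainder still fits within the 10000-part limit:
partSize := size / (s3manager.MaxUploadParts - 1) // e.g. 100 GiB -> ~10.2 MiB parts
```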
/*//UPLOAD_TYPE_FORM + Form struct { + //Headers struct{} `json:"headers"` + Kind string `json:"kind"` + Method string `json:"method"` + MultiParts struct { + OSSAccessKeyID string `json:"OSSAccessKeyId"` + Signature string `json:"Signature"` + Callback string `json:"callback"` + Key string `json:"key"` + Policy string `json:"policy"` + XUserData string `json:"x:user_data"` + } `json:"multi_parts"` + URL string `json:"url"` + } `json:"form"`*/ + + //UPLOAD_TYPE_RESUMABLE + Resumable struct { + Kind string `json:"kind"` + Params struct { + AccessKeyID string `json:"access_key_id"` + AccessKeySecret string `json:"access_key_secret"` + Bucket string `json:"bucket"` + Endpoint string `json:"endpoint"` + Expiration time.Time `json:"expiration"` + Key string `json:"key"` + SecurityToken string `json:"security_token"` + } `json:"params"` + Provider string `json:"provider"` + } `json:"resumable"` + + File Files `json:"file"` +} diff --git a/drivers/thunderx/util.go b/drivers/thunderx/util.go new file mode 100644 index 0000000000000000000000000000000000000000..d2a8b2addb85d03127ed860d71c4918815f775ae --- /dev/null +++ b/drivers/thunderx/util.go @@ -0,0 +1,312 @@ +package thunderx + +import ( + "crypto/md5" + "crypto/sha1" + "encoding/hex" + "fmt" + "io" + "net/http" + "regexp" + "strings" + "time" + + "github.com/alist-org/alist/v3/drivers/base" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/go-resty/resty/v2" +) + +const ( + API_URL = "https://api-pan.xunleix.com/drive/v1" + FILE_API_URL = API_URL + "/files" + XLUSER_API_URL = "https://xluser-ssl.xunleix.com/v1" +) + +var Algorithms = []string{ + "kVy0WbPhiE4v6oxXZ88DvoA3Q", + "lON/AUoZKj8/nBtcE85mVbkOaVdVa", + "rLGffQrfBKH0BgwQ33yZofvO3Or", + "FO6HWqw", + "GbgvyA2", + "L1NU9QvIQIH7DTRt", + "y7llk4Y8WfYflt6", + "iuDp1WPbV3HRZudZtoXChxH4HNVBX5ZALe", + "8C28RTXmVcco0", + "X5Xh", + "7xe25YUgfGgD0xW3ezFS", + "", + "CKCR", + "8EmDjBo6h3eLaK7U6vU2Qys0NsMx", + "t2TeZBXKqbdP09Arh9C3", +} + +const ( + ClientID = "ZQL_zwA4qhHcoe_2" + ClientSecret = "Og9Vr1L8Ee6bh0olFxFDRg" + ClientVersion = "1.06.0.2132" + PackageName = "com.thunder.downloader" + DownloadUserAgent = "Dalvik/2.1.0 (Linux; U; Android 13; M2004J7AC Build/SP1A.210812.016)" + SdkVersion = "2.0.3.203100 " +) + +const ( + FOLDER = "drive#folder" + FILE = "drive#file" + RESUMABLE = "drive#resumable" +) + +const ( + UPLOAD_TYPE_UNKNOWN = "UPLOAD_TYPE_UNKNOWN" + //UPLOAD_TYPE_FORM = "UPLOAD_TYPE_FORM" + UPLOAD_TYPE_RESUMABLE = "UPLOAD_TYPE_RESUMABLE" + UPLOAD_TYPE_URL = "UPLOAD_TYPE_URL" +) + +func GetAction(method string, url string) string { + urlpath := regexp.MustCompile(`://[^/]+((/[^/\s?#]+)*)`).FindStringSubmatch(url)[1] + return method + ":" + urlpath +} + +type Common struct { + client *resty.Client + + captchaToken string + userID string + // 签名相关,二选一 + Algorithms []string + Timestamp, CaptchaSign string + + // 必要值,签名相关 + DeviceID string + ClientID string + ClientSecret string + ClientVersion string + PackageName string + UserAgent string + DownloadUserAgent string + UseVideoUrl bool + UseProxy bool + //下载地址是否使用代理 + UseUrlProxy bool + ProxyUrl string + + // 验证码token刷新成功回调 + refreshCTokenCk func(token string) +} + +func (c *Common) SetDeviceID(deviceID string) { + c.DeviceID = deviceID +} + +func (c *Common) SetUserID(userID string) { + c.userID = userID +} + +func (c *Common) SetUserAgent(userAgent string) { + c.UserAgent = userAgent +} + +func (c *Common) SetCaptchaToken(captchaToken string) { + c.captchaToken = captchaToken +} +func (c *Common) GetCaptchaToken() string { + 
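Before the captcha helpers below, a sketch of the signature they produce: `GetCaptchaSign` folds MD5 over the device fingerprint once per comma-separated algorithm salt. Also note that `SdkVersion = "2.0.3.203100 "` above carries a trailing space that ends up doubled inside the User-Agent string; that may be deliberate mimicry of the official client, but it is easy to mistake for a typo:

```go
// Derivation implemented by GetCaptchaSign below; md5hex stands in for
// utils.GetMD5EncodeStr (hex-encoded MD5) and all names are illustrative.
str := clientID + clientVersion + packageName + deviceID + timestamp
for _, salt := range algorithms {
	str = md5hex(str + salt)
}
sign := "1." + str // the "1." prefix marks the signature version
```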
return c.captchaToken +} + +// 刷新验证码token(登录后) +func (c *Common) RefreshCaptchaTokenAtLogin(action, userID string) error { + metas := map[string]string{ + "client_version": c.ClientVersion, + "package_name": c.PackageName, + "user_id": userID, + } + metas["timestamp"], metas["captcha_sign"] = c.GetCaptchaSign() + return c.refreshCaptchaToken(action, metas) +} + +// 刷新验证码token(登录时) +func (c *Common) RefreshCaptchaTokenInLogin(action, username string) error { + metas := make(map[string]string) + if ok, _ := regexp.MatchString(`\w+([-+.]\w+)*@\w+([-.]\w+)*\.\w+([-.]\w+)*`, username); ok { + metas["email"] = username + } else if len(username) >= 11 && len(username) <= 18 { + metas["phone_number"] = username + } else { + metas["username"] = username + } + return c.refreshCaptchaToken(action, metas) +} + +// 获取验证码签名 +func (c *Common) GetCaptchaSign() (timestamp, sign string) { + if len(c.Algorithms) == 0 { + return c.Timestamp, c.CaptchaSign + } + timestamp = fmt.Sprint(time.Now().UnixMilli()) + str := fmt.Sprint(c.ClientID, c.ClientVersion, c.PackageName, c.DeviceID, timestamp) + for _, algorithm := range c.Algorithms { + str = utils.GetMD5EncodeStr(str + algorithm) + } + sign = "1." + str + return +} + +// 刷新验证码token +func (c *Common) refreshCaptchaToken(action string, metas map[string]string) error { + param := CaptchaTokenRequest{ + Action: action, + CaptchaToken: c.captchaToken, + ClientID: c.ClientID, + DeviceID: c.DeviceID, + Meta: metas, + RedirectUri: "xlaccsdk01://xbase.cloud/callback?state=harbor", + } + var e ErrResp + var resp CaptchaTokenResponse + _, err := c.Request(XLUSER_API_URL+"/shield/captcha/init", http.MethodPost, func(req *resty.Request) { + req.SetError(&e).SetBody(param) + }, &resp) + + if err != nil { + return err + } + + if e.IsError() { + return &e + } + + if resp.Url != "" { + return fmt.Errorf(`need verify: Click Here`, resp.Url) + } + + if resp.CaptchaToken == "" { + return fmt.Errorf("empty captchaToken") + } + + if c.refreshCTokenCk != nil { + c.refreshCTokenCk(resp.CaptchaToken) + } + c.SetCaptchaToken(resp.CaptchaToken) + return nil +} + +// Request 只有基础信息的请求 +func (c *Common) Request(url, method string, callback base.ReqCallback, resp interface{}) ([]byte, error) { + req := c.client.R().SetHeaders(map[string]string{ + "user-agent": c.UserAgent, + "accept": "application/json;charset=UTF-8", + "x-device-id": c.DeviceID, + "x-client-id": c.ClientID, + "x-client-version": c.ClientVersion, + }) + + if callback != nil { + callback(req) + } + if resp != nil { + req.SetResult(resp) + } + + reurl := url + + if c.UseProxy { + if strings.HasSuffix(c.ProxyUrl, "/") { + reurl = c.ProxyUrl + url + } else { + reurl = c.ProxyUrl + "/" + url + } + } + + res, err := req.Execute(method, reurl) + if err != nil { + return nil, err + } + + var erron ErrResp + utils.Json.Unmarshal(res.Body(), &erron) + if erron.IsError() { + return nil, &erron + } + + return res.Body(), nil +} + +// 计算文件Gcid +func getGcid(r io.Reader, size int64) (string, error) { + calcBlockSize := func(j int64) int64 { + var psize int64 = 0x40000 + for float64(j)/float64(psize) > 0x200 && psize < 0x200000 { + psize = psize << 1 + } + return psize + } + + hash1 := sha1.New() + hash2 := sha1.New() + readSize := calcBlockSize(size) + for { + hash2.Reset() + if n, err := utils.CopyWithBufferN(hash2, r, readSize); err != nil && n == 0 { + if err != io.EOF { + return "", err + } + break + } + hash1.Write(hash2.Sum(nil)) + } + return hex.EncodeToString(hash1.Sum(nil)), nil +} + +func generateDeviceSign(deviceID, 
packageName string) string { + + signatureBase := fmt.Sprintf("%s%s%s%s", deviceID, packageName, "1", "appkey") + + sha1Hash := sha1.New() + sha1Hash.Write([]byte(signatureBase)) + sha1Result := sha1Hash.Sum(nil) + + sha1String := hex.EncodeToString(sha1Result) + + md5Hash := md5.New() + md5Hash.Write([]byte(sha1String)) + md5Result := md5Hash.Sum(nil) + + md5String := hex.EncodeToString(md5Result) + + deviceSign := fmt.Sprintf("div101.%s%s", deviceID, md5String) + + return deviceSign +} + +func BuildCustomUserAgent(deviceID, clientID, appName, sdkVersion, clientVersion, packageName, userID string) string { + deviceSign := generateDeviceSign(deviceID, packageName) + var sb strings.Builder + + sb.WriteString(fmt.Sprintf("ANDROID-%s/%s ", appName, clientVersion)) + sb.WriteString("protocolVersion/200 ") + sb.WriteString("accesstype/ ") + sb.WriteString(fmt.Sprintf("clientid/%s ", clientID)) + sb.WriteString(fmt.Sprintf("clientversion/%s ", clientVersion)) + sb.WriteString("action_type/ ") + sb.WriteString("networktype/WIFI ") + sb.WriteString("sessionid/ ") + sb.WriteString(fmt.Sprintf("deviceid/%s ", deviceID)) + sb.WriteString("providername/NONE ") + sb.WriteString(fmt.Sprintf("devicesign/%s ", deviceSign)) + sb.WriteString("refresh_token/ ") + sb.WriteString(fmt.Sprintf("sdkversion/%s ", sdkVersion)) + sb.WriteString(fmt.Sprintf("datetime/%d ", time.Now().UnixMilli())) + sb.WriteString(fmt.Sprintf("usrno/%s ", userID)) + sb.WriteString(fmt.Sprintf("appname/%s ", appName)) + sb.WriteString(fmt.Sprintf("session_origin/ ")) + sb.WriteString(fmt.Sprintf("grant_type/ ")) + sb.WriteString(fmt.Sprintf("appid/ ")) + sb.WriteString(fmt.Sprintf("clientip/ ")) + sb.WriteString(fmt.Sprintf("devicename/Xiaomi_M2004j7ac ")) + sb.WriteString(fmt.Sprintf("osversion/13 ")) + sb.WriteString(fmt.Sprintf("platformversion/10 ")) + sb.WriteString(fmt.Sprintf("accessmode/ ")) + sb.WriteString(fmt.Sprintf("devicemodel/M2004J7AC ")) + + return sb.String() +} diff --git a/drivers/trainbit/driver.go b/drivers/trainbit/driver.go new file mode 100644 index 0000000000000000000000000000000000000000..795b2fb8a2e9f157e246d40e959a4e13c03fe85d --- /dev/null +++ b/drivers/trainbit/driver.go @@ -0,0 +1,142 @@ +package trainbit + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strings" + + "github.com/alist-org/alist/v3/drivers/base" + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/errs" + "github.com/alist-org/alist/v3/internal/model" +) + +type Trainbit struct { + model.Storage + Addition +} + +var apiExpiredate, guid string + +func (d *Trainbit) Config() driver.Config { + return config +} + +func (d *Trainbit) GetAddition() driver.Additional { + return &d.Addition +} + +func (d *Trainbit) Init(ctx context.Context) error { + base.HttpClient.CheckRedirect = func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + } + var err error + apiExpiredate, guid, err = getToken(d.ApiKey, d.AUSHELLPORTAL) + if err != nil { + return err + } + return nil +} + +func (d *Trainbit) Drop(ctx context.Context) error { + return nil +} + +func (d *Trainbit) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) { + form := make(url.Values) + form.Set("parentid", strings.Split(dir.GetID(), "_")[0]) + res, err := postForm("https://trainbit.com/lib/api/v1/listoffiles", form, apiExpiredate, d.ApiKey, d.AUSHELLPORTAL) + if err != nil { + return nil, err + } + data, err := io.ReadAll(res.Body) + if err != nil { + 
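Two notes around this point. In `BuildCustomUserAgent` above, the constant-only `fmt.Sprintf` calls (e.g. `fmt.Sprintf("session_origin/ ")`) can be plain `sb.WriteString` calls. More importantly, in `List` just below, `json.Unmarshal(data, &jsonData)` discards its return value and the following `if err != nil` re-checks the stale `io.ReadAll` error, so malformed JSON sails through and panics at the later type assertion. A likely intended shape:

```go
// Check the Unmarshal error itself rather than the earlier ReadAll error:
var jsonData any
if err := json.Unmarshal(data, &jsonData); err != nil {
	return nil, err
}
```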
+// generateDeviceSign derives the devicesign header value:
+// "div101." + deviceID + hex(md5(hex(sha1(deviceID + packageName + "1" + "appkey"))))
+func generateDeviceSign(deviceID, packageName string) string {
+	signatureBase := fmt.Sprintf("%s%s%s%s", deviceID, packageName, "1", "appkey")
+
+	sha1Hash := sha1.New()
+	sha1Hash.Write([]byte(signatureBase))
+	sha1String := hex.EncodeToString(sha1Hash.Sum(nil))
+
+	md5Hash := md5.New()
+	md5Hash.Write([]byte(sha1String))
+	md5String := hex.EncodeToString(md5Hash.Sum(nil))
+
+	return fmt.Sprintf("div101.%s%s", deviceID, md5String)
+}
+
+func BuildCustomUserAgent(deviceID, clientID, appName, sdkVersion, clientVersion, packageName, userID string) string {
+	deviceSign := generateDeviceSign(deviceID, packageName)
+	var sb strings.Builder
+
+	sb.WriteString(fmt.Sprintf("ANDROID-%s/%s ", appName, clientVersion))
+	sb.WriteString("protocolVersion/200 ")
+	sb.WriteString("accesstype/ ")
+	sb.WriteString(fmt.Sprintf("clientid/%s ", clientID))
+	sb.WriteString(fmt.Sprintf("clientversion/%s ", clientVersion))
+	sb.WriteString("action_type/ ")
+	sb.WriteString("networktype/WIFI ")
+	sb.WriteString("sessionid/ ")
+	sb.WriteString(fmt.Sprintf("deviceid/%s ", deviceID))
+	sb.WriteString("providername/NONE ")
+	sb.WriteString(fmt.Sprintf("devicesign/%s ", deviceSign))
+	sb.WriteString("refresh_token/ ")
+	sb.WriteString(fmt.Sprintf("sdkversion/%s ", sdkVersion))
+	sb.WriteString(fmt.Sprintf("datetime/%d ", time.Now().UnixMilli()))
+	sb.WriteString(fmt.Sprintf("usrno/%s ", userID))
+	sb.WriteString(fmt.Sprintf("appname/%s ", appName))
+	sb.WriteString("session_origin/ ")
+	sb.WriteString("grant_type/ ")
+	sb.WriteString("appid/ ")
+	sb.WriteString("clientip/ ")
+	sb.WriteString("devicename/Xiaomi_M2004j7ac ")
+	sb.WriteString("osversion/13 ")
+	sb.WriteString("platformversion/10 ")
+	sb.WriteString("accessmode/ ")
+	sb.WriteString("devicemodel/M2004J7AC ")
+
+	return sb.String()
+}
diff --git a/drivers/trainbit/driver.go b/drivers/trainbit/driver.go
new file mode 100644
index 0000000000000000000000000000000000000000..795b2fb8a2e9f157e246d40e959a4e13c03fe85d
--- /dev/null
+++ b/drivers/trainbit/driver.go
@@ -0,0 +1,142 @@
+package trainbit
+
+import (
+	"context"
+	"encoding/json"
+	"fmt"
+	"io"
+	"net/http"
+	"net/url"
+	"strings"
+
+	"github.com/alist-org/alist/v3/drivers/base"
+	"github.com/alist-org/alist/v3/internal/driver"
+	"github.com/alist-org/alist/v3/internal/errs"
+	"github.com/alist-org/alist/v3/internal/model"
+)
+
+type Trainbit struct {
+	model.Storage
+	Addition
+}
+
+var apiExpiredate, guid string
+
+func (d *Trainbit) Config() driver.Config {
+	return config
+}
+
+func (d *Trainbit) GetAddition() driver.Additional {
+	return &d.Addition
+}
+
+func (d *Trainbit) Init(ctx context.Context) error {
+	// keep redirects unfollowed so Link can read the Location header
+	base.HttpClient.CheckRedirect = func(req *http.Request, via []*http.Request) error {
+		return http.ErrUseLastResponse
+	}
+	var err error
+	apiExpiredate, guid, err = getToken(d.ApiKey, d.AUSHELLPORTAL)
+	if err != nil {
+		return err
+	}
+	return nil
+}
+
+func (d *Trainbit) Drop(ctx context.Context) error {
+	return nil
+}
+
+func (d *Trainbit) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) {
+	form := make(url.Values)
+	form.Set("parentid", strings.Split(dir.GetID(), "_")[0])
+	res, err := postForm("https://trainbit.com/lib/api/v1/listoffiles", form, apiExpiredate, d.ApiKey, d.AUSHELLPORTAL)
+	if err != nil {
+		return nil, err
+	}
+	defer res.Body.Close()
+	data, err := io.ReadAll(res.Body)
+	if err != nil {
+		return nil, err
+	}
+	var jsonData any
+	if err := json.Unmarshal(data, &jsonData); err != nil {
+		return nil, err
+	}
+	object, err := parseRawFileObject(jsonData.(map[string]any)["items"].([]any))
+	if err != nil {
+		return nil, err
+	}
+	return object, nil
+}
+
+func (d *Trainbit) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) {
+	res, err := get(fmt.Sprintf("https://trainbit.com/files/%s/", strings.Split(file.GetID(), "_")[0]), d.ApiKey, d.AUSHELLPORTAL)
+	if err != nil {
+		return nil, err
+	}
+	return &model.Link{
+		URL: res.Header.Get("Location"),
+	}, nil
+}
+
+func (d *Trainbit) MakeDir(ctx context.Context, parentDir model.Obj, dirName string) error {
+	form := make(url.Values)
+	form.Set("name", local2provider(dirName, true))
+	form.Set("parentid", strings.Split(parentDir.GetID(), "_")[0])
+	_, err := postForm("https://trainbit.com/lib/api/v1/createfolder", form, apiExpiredate, d.ApiKey, d.AUSHELLPORTAL)
+	return err
+}
+
+func (d *Trainbit) Move(ctx context.Context, srcObj, dstDir model.Obj) error {
+	form := make(url.Values)
+	form.Set("sourceid", strings.Split(srcObj.GetID(), "_")[0])
+	form.Set("destinationid", strings.Split(dstDir.GetID(), "_")[0])
+	_, err := postForm("https://trainbit.com/lib/api/v1/move", form, apiExpiredate, d.ApiKey, d.AUSHELLPORTAL)
+	return err
+}
+
+func (d *Trainbit) Rename(ctx context.Context, srcObj model.Obj, newName string) error {
+	form := make(url.Values)
+	form.Set("id", strings.Split(srcObj.GetID(), "_")[0])
+	form.Set("name", local2provider(newName, srcObj.IsDir()))
+	_, err := postForm("https://trainbit.com/lib/api/v1/edit", form, apiExpiredate, d.ApiKey, d.AUSHELLPORTAL)
+	return err
+}
+
+func (d *Trainbit) Copy(ctx context.Context, srcObj, dstDir model.Obj) error {
+	return errs.NotImplement
+}
+
+func (d *Trainbit) Remove(ctx context.Context, obj model.Obj) error {
+	form := make(url.Values)
+	form.Set("id", strings.Split(obj.GetID(), "_")[0])
+	_, err := postForm("https://trainbit.com/lib/api/v1/delete", form, apiExpiredate, d.ApiKey, d.AUSHELLPORTAL)
+	return err
+}
+
+func (d *Trainbit) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error {
+	endpoint, _ := url.Parse("https://tb28.trainbit.com/api/upload/send_raw/")
+	query := &url.Values{}
+	query.Add("q", strings.Split(dstDir.GetID(), "_")[1])
+	query.Add("guid", guid)
+	query.Add("name", url.QueryEscape(local2provider(stream.GetName(), false)+"."))
+	endpoint.RawQuery = query.Encode()
+	var total int64
+	progressReader := &ProgressReader{
+		stream,
+		func(byteNum int) {
+			total += int64(byteNum)
+			up(float64(total) / float64(stream.GetSize()) * 100)
+		},
+	}
+	req, err := http.NewRequest(http.MethodPost, endpoint.String(), progressReader)
+	if err != nil {
+		return err
+	}
+	req.Header.Set("Content-Type", "text/json; charset=UTF-8")
+	_, err = base.HttpClient.Do(req)
+	return err
+}
+
+var _ driver.Driver = (*Trainbit)(nil)
diff --git a/drivers/trainbit/meta.go b/drivers/trainbit/meta.go
new file mode 100644
index 0000000000000000000000000000000000000000..59c09d77e1c6c4d848ebb8e5523c7904eec2c842
--- /dev/null
+++ b/drivers/trainbit/meta.go
@@ -0,0 +1,29 @@
+package trainbit
+
+import (
+	"github.com/alist-org/alist/v3/internal/driver"
+	"github.com/alist-org/alist/v3/internal/op"
+)
+
+type Addition struct {
+	driver.RootID
+	AUSHELLPORTAL string `json:"AUSHELLPORTAL" required:"true"`
+	ApiKey        string `json:"apikey" required:"true"`
+}
+
+var config = driver.Config{
+	Name: "Trainbit",
+	
LocalSort: false, + OnlyLocal: false, + OnlyProxy: false, + NoCache: false, + NoUpload: false, + NeedMs: false, + DefaultRoot: "0_000", +} + +func init() { + op.RegisterDriver(func() driver.Driver { + return &Trainbit{} + }) +} diff --git a/drivers/trainbit/types.go b/drivers/trainbit/types.go new file mode 100644 index 0000000000000000000000000000000000000000..4de1a0abdf39495ea4a5d3424b0a08f441dda728 --- /dev/null +++ b/drivers/trainbit/types.go @@ -0,0 +1 @@ +package trainbit \ No newline at end of file diff --git a/drivers/trainbit/util.go b/drivers/trainbit/util.go new file mode 100644 index 0000000000000000000000000000000000000000..afc111a829054fe26e24af790a661ccfdc7b62e5 --- /dev/null +++ b/drivers/trainbit/util.go @@ -0,0 +1,135 @@ +package trainbit + +import ( + "html" + "io" + "net/http" + "net/url" + "regexp" + "strings" + "time" + + "github.com/alist-org/alist/v3/drivers/base" + "github.com/alist-org/alist/v3/internal/model" +) + +type ProgressReader struct { + io.Reader + reporter func(byteNum int) +} + +func (progressReader *ProgressReader) Read(data []byte) (int, error) { + byteNum, err := progressReader.Reader.Read(data) + progressReader.reporter(byteNum) + return byteNum, err +} + +func get(url string, apiKey string, AUSHELLPORTAL string) (*http.Response, error) { + req, err := http.NewRequest(http.MethodGet, url, nil) + if err != nil { + return nil, err + } + req.AddCookie(&http.Cookie{ + Name: ".AUSHELLPORTAL", + Value: AUSHELLPORTAL, + MaxAge: 2 * 60, + }) + req.AddCookie(&http.Cookie{ + Name: "retkeyapi", + Value: apiKey, + MaxAge: 2 * 60, + }) + res, err := base.HttpClient.Do(req) + return res, err +} + +func postForm(endpoint string, data url.Values, apiExpiredate string, apiKey string, AUSHELLPORTAL string) (*http.Response, error) { + extData := make(url.Values) + for key, value := range data { + extData[key] = make([]string, len(value)) + copy(extData[key], value) + } + extData.Set("apikey", apiKey) + extData.Set("expiredate", apiExpiredate) + req, err := http.NewRequest(http.MethodPost, endpoint, strings.NewReader(extData.Encode())) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.AddCookie(&http.Cookie{ + Name: ".AUSHELLPORTAL", + Value: AUSHELLPORTAL, + MaxAge: 2 * 60, + }) + req.AddCookie(&http.Cookie{ + Name: "retkeyapi", + Value: apiKey, + MaxAge: 2 * 60, + }) + res, err := base.HttpClient.Do(req) + return res, err +} + +func getToken(apiKey string, AUSHELLPORTAL string) (string, string, error) { + res, err := get("https://trainbit.com/files/", apiKey, AUSHELLPORTAL) + if err != nil { + return "", "", err + } + data, err := io.ReadAll(res.Body) + if err != nil { + return "", "", err + } + text := string(data) + apiExpiredateReg := regexp.MustCompile(`core.api.expiredate = '([^']*)';`) + result := apiExpiredateReg.FindAllStringSubmatch(text, -1) + apiExpiredate := result[0][1] + guidReg := regexp.MustCompile(`app.vars.upload.guid = '([^']*)';`) + result = guidReg.FindAllStringSubmatch(text, -1) + guid := result[0][1] + return apiExpiredate, guid, nil +} + +func local2provider(filename string, isFolder bool) string { + if isFolder { + return filename + } + return filename + ".delete_suffix" +} + +func provider2local(filename string) string { + filename = html.UnescapeString(filename) + index := strings.LastIndex(filename, ".delete_suffix") + if index != -1 { + filename = filename[:index] + } + return filename +} + +func parseRawFileObject(rawObject []any) ([]model.Obj, error) { + objectList := 
make([]model.Obj, 0) + for _, each := range rawObject { + object := each.(map[string]any) + if object["id"].(string) == "0" { + continue + } + isFolder := int64(object["ty"].(float64)) == 1 + var name string + if object["ext"].(string) != "" { + name = strings.Join([]string{object["name"].(string), object["ext"].(string)}, ".") + } else { + name = object["name"].(string) + } + modified, err := time.Parse("2006/01/02 15:04:05", object["modified"].(string)) + if err != nil { + return nil, err + } + objectList = append(objectList, model.Obj(&model.Object{ + ID: strings.Join([]string{object["id"].(string), strings.Split(object["uploadurl"].(string), "=")[1]}, "_"), + Name: provider2local(name), + Size: int64(object["byte"].(float64)), + Modified: modified.Add(-210 * time.Minute), + IsFolder: isFolder, + })) + } + return objectList, nil +} diff --git a/drivers/url_tree/driver.go b/drivers/url_tree/driver.go new file mode 100644 index 0000000000000000000000000000000000000000..6a45bb7d4e10cbd53863ab60b55606e9a91d9ba9 --- /dev/null +++ b/drivers/url_tree/driver.go @@ -0,0 +1,79 @@ +package url_tree + +import ( + "context" + stdpath "path" + + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/errs" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/pkg/utils" + log "github.com/sirupsen/logrus" +) + +type Urls struct { + model.Storage + Addition + root *Node +} + +func (d *Urls) Config() driver.Config { + return config +} + +func (d *Urls) GetAddition() driver.Additional { + return &d.Addition +} + +func (d *Urls) Init(ctx context.Context) error { + node, err := BuildTree(d.UrlStructure, d.HeadSize) + if err != nil { + return err + } + node.calSize() + d.root = node + return nil +} + +func (d *Urls) Drop(ctx context.Context) error { + return nil +} + +func (d *Urls) Get(ctx context.Context, path string) (model.Obj, error) { + node := GetNodeFromRootByPath(d.root, path) + return nodeToObj(node, path) +} + +func (d *Urls) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) { + node := GetNodeFromRootByPath(d.root, dir.GetPath()) + log.Debugf("path: %s, node: %+v", dir.GetPath(), node) + if node == nil { + return nil, errs.ObjectNotFound + } + if node.isFile() { + return nil, errs.NotFolder + } + return utils.SliceConvert(node.Children, func(node *Node) (model.Obj, error) { + return nodeToObj(node, stdpath.Join(dir.GetPath(), node.Name)) + }) +} + +func (d *Urls) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) { + node := GetNodeFromRootByPath(d.root, file.GetPath()) + log.Debugf("path: %s, node: %+v", file.GetPath(), node) + if node == nil { + return nil, errs.ObjectNotFound + } + if node.isFile() { + return &model.Link{ + URL: node.Url, + }, nil + } + return nil, errs.NotFile +} + +//func (d *Template) Other(ctx context.Context, args model.OtherArgs) (interface{}, error) { +// return nil, errs.NotSupport +//} + +var _ driver.Driver = (*Urls)(nil) diff --git a/drivers/url_tree/meta.go b/drivers/url_tree/meta.go new file mode 100644 index 0000000000000000000000000000000000000000..b3ae33dc059e233238ea556c76512244529984df --- /dev/null +++ b/drivers/url_tree/meta.go @@ -0,0 +1,35 @@ +package url_tree + +import ( + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/op" +) + +type Addition struct { + // Usually one of two + // driver.RootPath + // driver.RootID + // define other + UrlStructure string `json:"url_structure" 
type:"text" required:"true" default:"https://jsd.nn.ci/gh/alist-org/alist/README.md\nhttps://jsd.nn.ci/gh/alist-org/alist/README_cn.md\nfolder:\n CONTRIBUTING.md:1635:https://jsd.nn.ci/gh/alist-org/alist/CONTRIBUTING.md\n CODE_OF_CONDUCT.md:2093:https://jsd.nn.ci/gh/alist-org/alist/CODE_OF_CONDUCT.md" help:"structure:FolderName:\n [FileName:][FileSize:][Modified:]Url"` + HeadSize bool `json:"head_size" type:"bool" default:"false" help:"Use head method to get file size, but it may be failed."` +} + +var config = driver.Config{ + Name: "UrlTree", + LocalSort: true, + OnlyLocal: false, + OnlyProxy: false, + NoCache: true, + NoUpload: true, + NeedMs: false, + DefaultRoot: "", + CheckStatus: true, + Alert: "", + NoOverwriteUpload: false, +} + +func init() { + op.RegisterDriver(func() driver.Driver { + return &Urls{} + }) +} diff --git a/drivers/url_tree/types.go b/drivers/url_tree/types.go new file mode 100644 index 0000000000000000000000000000000000000000..7e8ca3d93aed4ea9eea395c1d88f03f1163ab859 --- /dev/null +++ b/drivers/url_tree/types.go @@ -0,0 +1,46 @@ +package url_tree + +// Node is a node in the folder tree +type Node struct { + Url string + Name string + Level int + Modified int64 + Size int64 + Children []*Node +} + +func (node *Node) getByPath(paths []string) *Node { + if len(paths) == 0 || node == nil { + return nil + } + if node.Name != paths[0] { + return nil + } + if len(paths) == 1 { + return node + } + for _, child := range node.Children { + tmp := child.getByPath(paths[1:]) + if tmp != nil { + return tmp + } + } + return nil +} + +func (node *Node) isFile() bool { + return node.Url != "" +} + +func (node *Node) calSize() int64 { + if node.isFile() { + return node.Size + } + var size int64 = 0 + for _, child := range node.Children { + size += child.calSize() + } + node.Size = size + return size +} diff --git a/drivers/url_tree/urls_test.go b/drivers/url_tree/urls_test.go new file mode 100644 index 0000000000000000000000000000000000000000..1512e083c914fe9d025af37f4c17c9d26b59b8b0 --- /dev/null +++ b/drivers/url_tree/urls_test.go @@ -0,0 +1,47 @@ +package url_tree_test + +import ( + "testing" + + "github.com/alist-org/alist/v3/drivers/url_tree" +) + +func testTree() (*url_tree.Node, error) { + text := `folder1: + name1:https://url1 + http://url2 + folder2: + http://url3 + http://url4 + http://url5 +folder3: + http://url6 + http://url7 +http://url8` + return url_tree.BuildTree(text, false) +} + +func TestBuildTree(t *testing.T) { + node, err := testTree() + if err != nil { + t.Errorf("failed to build tree: %+v", err) + } else { + t.Logf("tree: %+v", node) + } +} + +func TestGetNode(t *testing.T) { + root, err := testTree() + if err != nil { + t.Errorf("failed to build tree: %+v", err) + return + } + node := url_tree.GetNodeFromRootByPath(root, "/") + if node != root { + t.Errorf("got wrong node: %+v", node) + } + url3 := url_tree.GetNodeFromRootByPath(root, "/folder1/folder2/url3") + if url3 != root.Children[0].Children[2].Children[0] { + t.Errorf("got wrong node: %+v", url3) + } +} diff --git a/drivers/url_tree/util.go b/drivers/url_tree/util.go new file mode 100644 index 0000000000000000000000000000000000000000..4065218fcc17756c6042e36bf15f4d689f72d4ae --- /dev/null +++ b/drivers/url_tree/util.go @@ -0,0 +1,192 @@ +package url_tree + +import ( + "fmt" + stdpath "path" + "strconv" + "strings" + "time" + + "github.com/alist-org/alist/v3/drivers/base" + "github.com/alist-org/alist/v3/internal/errs" + "github.com/alist-org/alist/v3/internal/model" + log "github.com/sirupsen/logrus" 
+) + +// build tree from text, text structure definition: +/** + * FolderName: + * [FileName:][FileSize:][Modified:]Url + */ +/** + * For example: + * folder1: + * name1:url1 + * url2 + * folder2: + * url3 + * url4 + * url5 + * folder3: + * url6 + * url7 + * url8 + */ +// if there are no name, use the last segment of url as name +func BuildTree(text string, headSize bool) (*Node, error) { + lines := strings.Split(text, "\n") + var root = &Node{Level: -1, Name: "root"} + stack := []*Node{root} + for _, line := range lines { + // calculate indent + indent := 0 + for i := 0; i < len(line); i++ { + if line[i] != ' ' { + break + } + indent++ + } + // if indent is not a multiple of 2, it is an error + if indent%2 != 0 { + return nil, fmt.Errorf("the line '%s' is not a multiple of 2", line) + } + // calculate level + level := indent / 2 + line = strings.TrimSpace(line[indent:]) + // if the line is empty, skip + if line == "" { + continue + } + // if level isn't greater than the level of the top of the stack + // it is not the child of the top of the stack + for level <= stack[len(stack)-1].Level { + // pop the top of the stack + stack = stack[:len(stack)-1] + } + // if the line is a folder + if isFolder(line) { + // create a new node + node := &Node{ + Level: level, + Name: strings.TrimSuffix(line, ":"), + } + // add the node to the top of the stack + stack[len(stack)-1].Children = append(stack[len(stack)-1].Children, node) + // push the node to the stack + stack = append(stack, node) + } else { + // if the line is a file + // create a new node + node, err := parseFileLine(line, headSize) + if err != nil { + return nil, err + } + node.Level = level + // add the node to the top of the stack + stack[len(stack)-1].Children = append(stack[len(stack)-1].Children, node) + } + } + return root, nil +} + +func isFolder(line string) bool { + return strings.HasSuffix(line, ":") +} + +// line definition: +// [FileName:][FileSize:][Modified:]Url +func parseFileLine(line string, headSize bool) (*Node, error) { + // if there is no url, it is an error + if !strings.Contains(line, "http://") && !strings.Contains(line, "https://") { + return nil, fmt.Errorf("invalid line: %s, because url is required for file", line) + } + index := strings.Index(line, "http://") + if index == -1 { + index = strings.Index(line, "https://") + } + url := line[index:] + info := line[:index] + node := &Node{ + Url: url, + } + haveSize := false + if index > 0 { + if !strings.HasSuffix(info, ":") { + return nil, fmt.Errorf("invalid line: %s, because file info must end with ':'", line) + } + info = info[:len(info)-1] + if info == "" { + return nil, fmt.Errorf("invalid line: %s, because file name can't be empty", line) + } + infoParts := strings.Split(info, ":") + node.Name = infoParts[0] + if len(infoParts) > 1 { + size, err := strconv.ParseInt(infoParts[1], 10, 64) + if err != nil { + return nil, fmt.Errorf("invalid line: %s, because file size must be an integer", line) + } + node.Size = size + haveSize = true + if len(infoParts) > 2 { + modified, err := strconv.ParseInt(infoParts[2], 10, 64) + if err != nil { + return nil, fmt.Errorf("invalid line: %s, because file modified must be an unix timestamp", line) + } + node.Modified = modified + } + } + } else { + node.Name = stdpath.Base(url) + } + if !haveSize && headSize { + size, err := getSizeFromUrl(url) + if err != nil { + log.Errorf("get size from url error: %s", err) + } else { + node.Size = size + } + } + return node, nil +} + +func splitPath(path string) []string { + if path == "/" 
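+
+// For instance, the default UrlStructure line
+//   CONTRIBUTING.md:1635:https://jsd.nn.ci/gh/alist-org/alist/CONTRIBUTING.md
+// parses to Name "CONTRIBUTING.md", Size 1635 and the trailing Url.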
+
+func splitPath(path string) []string {
+	if path == "/" {
+		return []string{"root"}
+	}
+	parts := strings.Split(path, "/")
+	parts[0] = "root"
+	return parts
+}
+
+func GetNodeFromRootByPath(root *Node, path string) *Node {
+	return root.getByPath(splitPath(path))
+}
+
+func nodeToObj(node *Node, path string) (model.Obj, error) {
+	if node == nil {
+		return nil, errs.ObjectNotFound
+	}
+	return &model.Object{
+		Name:     node.Name,
+		Size:     node.Size,
+		Modified: time.Unix(node.Modified, 0),
+		IsFolder: !node.isFile(),
+		Path:     path,
+	}, nil
+}
+
+func getSizeFromUrl(url string) (int64, error) {
+	res, err := base.RestyClient.R().SetDoNotParseResponse(true).Head(url)
+	if err != nil {
+		return 0, err
+	}
+	defer res.RawResponse.Body.Close()
+	if res.StatusCode() >= 300 {
+		return 0, fmt.Errorf("get size from url %s failed, status code: %d", url, res.StatusCode())
+	}
+	size, err := strconv.ParseInt(res.Header().Get("Content-Length"), 10, 64)
+	if err != nil {
+		return 0, err
+	}
+	return size, nil
+}
diff --git a/drivers/uss/driver.go b/drivers/uss/driver.go
new file mode 100644
index 0000000000000000000000000000000000000000..447515d8d36c3de748cfbc9ba118ec6cc7d1bb5e
--- /dev/null
+++ b/drivers/uss/driver.go
@@ -0,0 +1,133 @@
+package uss
+
+import (
+	"context"
+	"fmt"
+	"net/url"
+	"path"
+	"strings"
+	"time"
+
+	"github.com/alist-org/alist/v3/internal/driver"
+	"github.com/alist-org/alist/v3/internal/model"
+	"github.com/alist-org/alist/v3/pkg/utils"
+	"github.com/upyun/go-sdk/v3/upyun"
+)
+
+type USS struct {
+	model.Storage
+	Addition
+	client *upyun.UpYun
+}
+
+func (d *USS) Config() driver.Config {
+	return config
+}
+
+func (d *USS) GetAddition() driver.Additional {
+	return &d.Addition
+}
+
+func (d *USS) Init(ctx context.Context) error {
+	d.client = upyun.NewUpYun(&upyun.UpYunConfig{
+		Bucket:   d.Bucket,
+		Operator: d.OperatorName,
+		Password: d.OperatorPassword,
+	})
+	return nil
+}
+
+func (d *USS) Drop(ctx context.Context) error {
+	return nil
+}
+
+func (d *USS) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) {
+	prefix := getKey(dir.GetPath(), true)
+	objsChan := make(chan *upyun.FileInfo, 10)
+	var err error
+	done := make(chan struct{})
+	go func() {
+		defer close(done)
+		err = d.client.List(&upyun.GetObjectsConfig{
+			Path:           prefix,
+			ObjectsChan:    objsChan,
+			MaxListObjects: 0,
+			MaxListLevel:   1,
+		})
+	}()
+	res := make([]model.Obj, 0)
+	for obj := range objsChan {
+		t := obj.Time
+		f := model.Object{
+			Name:     obj.Name,
+			Size:     obj.Size,
+			Modified: t,
+			IsFolder: obj.IsDir,
+		}
+		res = append(res, &f)
+	}
+	// the SDK closes ObjectsChan when listing finishes; wait for the goroutine
+	// so that reading err below cannot race with the assignment above
+	<-done
+	return res, err
+}
+
+func (d *USS) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) {
+	key := getKey(file.GetPath(), false)
+	host := d.Endpoint
+	if !strings.Contains(host, "://") { // default to https when the endpoint has no scheme
+		host = "https://" + host
+	}
+	u := fmt.Sprintf("%s/%s", host, key)
+	downExp := time.Hour * time.Duration(d.SignURLExpire)
+	expireAt := time.Now().Add(downExp).Unix()
+	upd := url.QueryEscape(path.Base(file.GetPath()))
+	tokenOrPassword := d.AntiTheftChainToken
+	if tokenOrPassword == "" {
+		tokenOrPassword = d.OperatorPassword
+	}
+	// anti-leech token: the middle 8 hex chars of md5("token&expire&/key") plus the expiry
+	signStr := strings.Join([]string{tokenOrPassword, fmt.Sprint(expireAt), fmt.Sprintf("/%s", key)}, "&")
+	upt := utils.GetMD5EncodeStr(signStr)[12:20] + fmt.Sprint(expireAt)
+	link := fmt.Sprintf("%s?_upd=%s&_upt=%s", u, upd, upt)
+	return &model.Link{URL: link}, nil
+}
+
+func (d *USS) MakeDir(ctx context.Context, parentDir model.Obj, dirName string) error {
+	return d.client.Mkdir(getKey(path.Join(parentDir.GetPath(), dirName), true))
+}
+
+func (d *USS) Move(ctx context.Context, srcObj, dstDir model.Obj) error {
+	return d.client.Move(&upyun.MoveObjectConfig{
+		SrcPath:  getKey(srcObj.GetPath(), srcObj.IsDir()),
+		DestPath: getKey(path.Join(dstDir.GetPath(), srcObj.GetName()), srcObj.IsDir()),
+	})
+}
+
+func (d *USS) Rename(ctx context.Context, srcObj model.Obj, newName string) error {
+	return d.client.Move(&upyun.MoveObjectConfig{
+		SrcPath:  getKey(srcObj.GetPath(), srcObj.IsDir()),
+		DestPath: getKey(path.Join(path.Dir(srcObj.GetPath()), newName), srcObj.IsDir()),
+	})
+}
+
+func (d *USS) Copy(ctx context.Context, srcObj, dstDir model.Obj) error {
+	return d.client.Copy(&upyun.CopyObjectConfig{
+		SrcPath:  getKey(srcObj.GetPath(), srcObj.IsDir()),
+		DestPath: getKey(path.Join(dstDir.GetPath(), srcObj.GetName()), srcObj.IsDir()),
+	})
+}
+
+func (d *USS) Remove(ctx context.Context, obj model.Obj) error {
+	return d.client.Delete(&upyun.DeleteObjectConfig{
+		Path:  getKey(obj.GetPath(), obj.IsDir()),
+		Async: false,
+	})
+}
+
+func (d *USS) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error {
+	// TODO: cancellation is not supported here
+	return d.client.Put(&upyun.PutObjectConfig{
+		Path:   getKey(path.Join(dstDir.GetPath(), stream.GetName()), false),
+		Reader: stream,
+	})
+}
+
+var _ driver.Driver = (*USS)(nil)
diff --git a/drivers/uss/meta.go b/drivers/uss/meta.go
new file mode 100644
index 0000000000000000000000000000000000000000..388df589cb2941b3cd9da6d4ef631b51e6cfde8e
--- /dev/null
+++ b/drivers/uss/meta.go
@@ -0,0 +1,29 @@
+package uss
+
+import (
+	"github.com/alist-org/alist/v3/internal/driver"
+	"github.com/alist-org/alist/v3/internal/op"
+)
+
+type Addition struct {
+	driver.RootPath
+	Bucket              string `json:"bucket" required:"true"`
+	Endpoint            string `json:"endpoint" required:"true"`
+	OperatorName        string `json:"operator_name" required:"true"`
+	OperatorPassword    string `json:"operator_password" required:"true"`
+	AntiTheftChainToken string `json:"anti_theft_chain_token" required:"false" default:""`
+	//CustomHost string `json:"custom_host"` // removed: Endpoint already covers what CustomHost did
+	SignURLExpire int `json:"sign_url_expire" type:"number" default:"4"`
+}
+
+var config = driver.Config{
+	Name:        "USS",
+	LocalSort:   true,
+	DefaultRoot: "/",
+}
+
+func init() {
+	op.RegisterDriver(func() driver.Driver {
+		return &USS{}
+	})
+}
diff --git a/drivers/uss/types.go b/drivers/uss/types.go
new file mode 100644
index 0000000000000000000000000000000000000000..169aaba047c4325550fa8042115a6489f4b945f9
--- /dev/null
+++ b/drivers/uss/types.go
@@ -0,0 +1 @@
+package uss
diff --git a/drivers/uss/util.go b/drivers/uss/util.go
new file mode 100644
index 0000000000000000000000000000000000000000..8b57a0b04a00f20d1ef711b1b644c91dd555daa8
--- /dev/null
+++ b/drivers/uss/util.go
@@ -0,0 +1,13 @@
+package uss
+
+import "strings"
+
+// helpers that are not part of the Driver interface
+
+func getKey(path string, dir bool) string {
+	path = strings.TrimPrefix(path, "/")
+	if dir {
+		path += "/"
+	}
+	return path
+}
diff --git a/drivers/virtual/driver.go b/drivers/virtual/driver.go
new file mode 100644
index 0000000000000000000000000000000000000000..d5d37e6bac1d978b4ea718296c995d71c5e74a46
--- /dev/null
+++ b/drivers/virtual/driver.go
@@ -0,0 +1,112 @@
+package virtual
+
+import (
+	"context"
+	"io"
+	"time"
+
+	"github.com/alist-org/alist/v3/internal/driver"
+	"github.com/alist-org/alist/v3/internal/model"
+	"github.com/alist-org/alist/v3/pkg/utils/random"
+)
+
+type Virtual struct {
+	model.Storage
+	Addition
+}
+
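+// The methods below satisfy driver.Driver entirely in memory: List fabricates
+// random objects via genObj and Link streams random bytes, so the driver can
+// exercise alist without a real backend.
+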
+func (d *Virtual) Config() driver.Config { + return config +} + +func (d *Virtual) Init(ctx context.Context) error { + return nil +} + +func (d *Virtual) Drop(ctx context.Context) error { + return nil +} + +func (d *Virtual) GetAddition() driver.Additional { + return &d.Addition +} + +func (d *Virtual) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) { + var res []model.Obj + for i := 0; i < d.NumFile; i++ { + res = append(res, d.genObj(false)) + } + for i := 0; i < d.NumFolder; i++ { + res = append(res, d.genObj(true)) + } + return res, nil +} + +type DummyMFile struct { + io.Reader +} + +func (f DummyMFile) Read(p []byte) (n int, err error) { + return f.Reader.Read(p) +} + +func (f DummyMFile) ReadAt(p []byte, off int64) (n int, err error) { + return f.Reader.Read(p) +} + +func (f DummyMFile) Close() error { + return nil +} + +func (DummyMFile) Seek(offset int64, whence int) (int64, error) { + return offset, nil +} + +func (d *Virtual) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) { + return &model.Link{ + MFile: DummyMFile{Reader: random.Rand}, + }, nil +} + +func (d *Virtual) MakeDir(ctx context.Context, parentDir model.Obj, dirName string) (model.Obj, error) { + dir := &model.Object{ + Name: dirName, + Size: 0, + IsFolder: true, + Modified: time.Now(), + } + return dir, nil +} + +func (d *Virtual) Move(ctx context.Context, srcObj, dstDir model.Obj) (model.Obj, error) { + return srcObj, nil +} + +func (d *Virtual) Rename(ctx context.Context, srcObj model.Obj, newName string) (model.Obj, error) { + obj := &model.Object{ + Name: newName, + Size: srcObj.GetSize(), + IsFolder: srcObj.IsDir(), + Modified: time.Now(), + } + return obj, nil +} + +func (d *Virtual) Copy(ctx context.Context, srcObj, dstDir model.Obj) (model.Obj, error) { + return srcObj, nil +} + +func (d *Virtual) Remove(ctx context.Context, obj model.Obj) error { + return nil +} + +func (d *Virtual) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) (model.Obj, error) { + file := &model.Object{ + Name: stream.GetName(), + Size: stream.GetSize(), + Modified: time.Now(), + } + return file, nil +} + +var _ driver.Driver = (*Virtual)(nil) diff --git a/drivers/virtual/meta.go b/drivers/virtual/meta.go new file mode 100644 index 0000000000000000000000000000000000000000..cb24b0295db308ec432ef5ca290ed2fbca0213f8 --- /dev/null +++ b/drivers/virtual/meta.go @@ -0,0 +1,28 @@ +package virtual + +import ( + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/op" +) + +type Addition struct { + driver.RootPath + NumFile int `json:"num_file" type:"number" default:"30" required:"true"` + NumFolder int `json:"num_folder" type:"number" default:"30" required:"true"` + MaxFileSize int64 `json:"max_file_size" type:"number" default:"1073741824" required:"true"` + MinFileSize int64 `json:"min_file_size" type:"number" default:"1048576" required:"true"` +} + +var config = driver.Config{ + Name: "Virtual", + OnlyLocal: true, + LocalSort: true, + NeedMs: true, + //NoCache: true, +} + +func init() { + op.RegisterDriver(func() driver.Driver { + return &Virtual{} + }) +} diff --git a/drivers/virtual/util.go b/drivers/virtual/util.go new file mode 100644 index 0000000000000000000000000000000000000000..5ed8314c53e2f6e8e1007e133ed530c7f32bc01c --- /dev/null +++ b/drivers/virtual/util.go @@ -0,0 +1,22 @@ +package virtual + +import ( + "time" + + "github.com/alist-org/alist/v3/internal/model" + 
"github.com/alist-org/alist/v3/pkg/utils/random" +) + +func (d *Virtual) genObj(dir bool) model.Obj { + obj := &model.Object{ + Name: random.String(10), + Size: 0, + IsFolder: true, + Modified: time.Now(), + } + if !dir { + obj.Size = random.RangeInt64(d.MinFileSize, d.MaxFileSize) + obj.IsFolder = false + } + return obj +} diff --git a/drivers/vtencent/drive.go b/drivers/vtencent/drive.go new file mode 100644 index 0000000000000000000000000000000000000000..36a9167234e7285a24a37eb206f9e99e27ac264c --- /dev/null +++ b/drivers/vtencent/drive.go @@ -0,0 +1,210 @@ +package vtencent + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "strconv" + "time" + + "github.com/alist-org/alist/v3/drivers/base" + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/errs" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/internal/op" + "github.com/alist-org/alist/v3/pkg/cron" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/go-resty/resty/v2" +) + +type Vtencent struct { + model.Storage + Addition + cron *cron.Cron + config driver.Config + conf Conf +} + +func (d *Vtencent) Config() driver.Config { + return d.config +} + +func (d *Vtencent) GetAddition() driver.Additional { + return &d.Addition +} + +func (d *Vtencent) Init(ctx context.Context) error { + tfUid, err := d.LoadUser() + if err != nil { + d.Status = err.Error() + op.MustSaveDriverStorage(d) + return nil + } + d.Addition.TfUid = tfUid + op.MustSaveDriverStorage(d) + d.cron = cron.NewCron(time.Hour * 12) + d.cron.Do(func() { + _, err := d.LoadUser() + if err != nil { + d.Status = err.Error() + op.MustSaveDriverStorage(d) + } + }) + return nil +} + +func (d *Vtencent) Drop(ctx context.Context) error { + if d.cron != nil { + d.cron.Stop() + } + return nil +} + +func (d *Vtencent) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) { + files, err := d.GetFiles(dir.GetID()) + if err != nil { + return nil, err + } + return utils.SliceConvert(files, func(src File) (model.Obj, error) { + return fileToObj(src), nil + }) +} + +func (d *Vtencent) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) { + form := fmt.Sprintf(`{"MaterialIds":["%s"]}`, file.GetID()) + var dat map[string]interface{} + if err := json.Unmarshal([]byte(form), &dat); err != nil { + return nil, err + } + var resps RspDown + api := "https://api.vs.tencent.com/SaaS/Material/DescribeMaterialDownloadUrl" + rsp, err := d.request(api, http.MethodPost, func(req *resty.Request) { + req.SetBody(dat) + }, &resps) + if err != nil { + return nil, err + } + if err := json.Unmarshal(rsp, &resps); err != nil { + return nil, err + } + if len(resps.Data.DownloadURLInfoSet) == 0 { + return nil, err + } + u := resps.Data.DownloadURLInfoSet[0].DownloadURL + link := &model.Link{ + URL: u, + Header: http.Header{ + "Referer": []string{d.conf.referer}, + "User-Agent": []string{d.conf.ua}, + }, + Concurrency: 2, + PartSize: 10 * utils.MB, + } + if file.GetSize() == 0 { + link.Concurrency = 0 + link.PartSize = 0 + } + return link, nil +} + +func (d *Vtencent) MakeDir(ctx context.Context, parentDir model.Obj, dirName string) error { + classId, err := strconv.Atoi(parentDir.GetID()) + if err != nil { + return err + } + _, err = d.request("https://api.vs.tencent.com/PaaS/Material/CreateClass", http.MethodPost, func(req *resty.Request) { + req.SetBody(base.Json{ + "Owner": base.Json{ + "Type": "PERSON", + "Id": d.TfUid, + }, + "ParentClassId": classId, + "Name": 
dirName, + "VerifySign": ""}) + }, nil) + return err +} + +func (d *Vtencent) Move(ctx context.Context, srcObj, dstDir model.Obj) error { + srcType := "MATERIAL" + if srcObj.IsDir() { + srcType = "CLASS" + } + form := fmt.Sprintf(`{"SourceInfos":[ + {"Owner":{"Id":"%s","Type":"PERSON"}, + "Resource":{"Type":"%s","Id":"%s"}} + ], + "Destination":{"Owner":{"Id":"%s","Type":"PERSON"}, + "Resource":{"Type":"CLASS","Id":"%s"}} + }`, d.TfUid, srcType, srcObj.GetID(), d.TfUid, dstDir.GetID()) + var dat map[string]interface{} + if err := json.Unmarshal([]byte(form), &dat); err != nil { + return err + } + _, err := d.request("https://api.vs.tencent.com/PaaS/Material/MoveResource", http.MethodPost, func(req *resty.Request) { + req.SetBody(dat) + }, nil) + return err +} + +func (d *Vtencent) Rename(ctx context.Context, srcObj model.Obj, newName string) error { + api := "https://api.vs.tencent.com/PaaS/Material/ModifyMaterial" + form := fmt.Sprintf(`{ + "Owner":{"Type":"PERSON","Id":"%s"}, + "MaterialId":"%s","Name":"%s"}`, d.TfUid, srcObj.GetID(), newName) + if srcObj.IsDir() { + classId, err := strconv.Atoi(srcObj.GetID()) + if err != nil { + return err + } + api = "https://api.vs.tencent.com/PaaS/Material/ModifyClass" + form = fmt.Sprintf(`{"Owner":{"Type":"PERSON","Id":"%s"}, + "ClassId":%d,"Name":"%s"}`, d.TfUid, classId, newName) + } + var dat map[string]interface{} + if err := json.Unmarshal([]byte(form), &dat); err != nil { + return err + } + _, err := d.request(api, http.MethodPost, func(req *resty.Request) { + req.SetBody(dat) + }, nil) + return err +} + +func (d *Vtencent) Copy(ctx context.Context, srcObj, dstDir model.Obj) error { + // TODO copy obj, optional + return errs.NotImplement +} + +func (d *Vtencent) Remove(ctx context.Context, obj model.Obj) error { + srcType := "MATERIAL" + if obj.IsDir() { + srcType = "CLASS" + } + form := fmt.Sprintf(`{ + "SourceInfos":[ + {"Owner":{"Type":"PERSON","Id":"%s"}, + "Resource":{"Type":"%s","Id":"%s"}} + ] + }`, d.TfUid, srcType, obj.GetID()) + var dat map[string]interface{} + if err := json.Unmarshal([]byte(form), &dat); err != nil { + return err + } + _, err := d.request("https://api.vs.tencent.com/PaaS/Material/DeleteResource", http.MethodPost, func(req *resty.Request) { + req.SetBody(dat) + }, nil) + return err +} + +func (d *Vtencent) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error { + err := d.FileUpload(ctx, dstDir, stream, up) + return err +} + +//func (d *Vtencent) Other(ctx context.Context, args model.OtherArgs) (interface{}, error) { +// return nil, errs.NotSupport +//} + +var _ driver.Driver = (*Vtencent)(nil) diff --git a/drivers/vtencent/meta.go b/drivers/vtencent/meta.go new file mode 100644 index 0000000000000000000000000000000000000000..3bb6cf746399fa08b80cbacc6a0860e2fdc8fbbe --- /dev/null +++ b/drivers/vtencent/meta.go @@ -0,0 +1,39 @@ +package vtencent + +import ( + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/op" +) + +type Addition struct { + driver.RootID + Cookie string `json:"cookie" required:"true"` + TfUid string `json:"tf_uid"` + OrderBy string `json:"order_by" type:"select" options:"Name,Size,UpdateTime,CreatTime"` + OrderDirection string `json:"order_direction" type:"select" options:"Asc,Desc"` +} + +type Conf struct { + ua string + referer string + origin string +} + +func init() { + op.RegisterDriver(func() driver.Driver { + return &Vtencent{ + config: driver.Config{ + Name: "VTencent", + OnlyProxy: true, + OnlyLocal: 
false, + DefaultRoot: "9", + NoOverwriteUpload: true, + }, + conf: Conf{ + ua: "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) quark-cloud-drive/2.5.20 Chrome/100.0.4896.160 Electron/18.3.5.4-b478491100 Safari/537.36 Channel/pckk_other_ch", + referer: "https://app.v.tencent.com/", + origin: "https://app.v.tencent.com", + }, + } + }) +} diff --git a/drivers/vtencent/signature.go b/drivers/vtencent/signature.go new file mode 100644 index 0000000000000000000000000000000000000000..14fda9bdc217bbf02850a3be4743a99e880c2b2a --- /dev/null +++ b/drivers/vtencent/signature.go @@ -0,0 +1,33 @@ +package vtencent + +import ( + "crypto/hmac" + "crypto/sha1" + "encoding/hex" +) + +func QSignatureKey(timeKey string, signPath string, key string) string { + signKey := hmac.New(sha1.New, []byte(key)) + signKey.Write([]byte(timeKey)) + signKeyBytes := signKey.Sum(nil) + signKeyHex := hex.EncodeToString(signKeyBytes) + sha := sha1.New() + sha.Write([]byte(signPath)) + shaBytes := sha.Sum(nil) + shaHex := hex.EncodeToString(shaBytes) + + O := "sha1\n" + timeKey + "\n" + shaHex + "\n" + dataSignKey := hmac.New(sha1.New, []byte(signKeyHex)) + dataSignKey.Write([]byte(O)) + dataSignKeyBytes := dataSignKey.Sum(nil) + dataSignKeyHex := hex.EncodeToString(dataSignKeyBytes) + return dataSignKeyHex +} + +func QTwoSignatureKey(timeKey string, key string) string { + signKey := hmac.New(sha1.New, []byte(key)) + signKey.Write([]byte(timeKey)) + signKeyBytes := signKey.Sum(nil) + signKeyHex := hex.EncodeToString(signKeyBytes) + return signKeyHex +} diff --git a/drivers/vtencent/types.go b/drivers/vtencent/types.go new file mode 100644 index 0000000000000000000000000000000000000000..b967481e25385ce0dc26f426b66d4dddff89403d --- /dev/null +++ b/drivers/vtencent/types.go @@ -0,0 +1,252 @@ +package vtencent + +import ( + "strconv" + "time" + + "github.com/alist-org/alist/v3/internal/model" +) + +type RespErr struct { + Code string `json:"Code"` + Message string `json:"Message"` +} + +type Reqfiles struct { + ScrollToken string `json:"ScrollToken"` + Text string `json:"Text"` + Offset int `json:"Offset"` + Limit int `json:"Limit"` + Sort struct { + Field string `json:"Field"` + Order string `json:"Order"` + } `json:"Sort"` + CreateTimeRanges []any `json:"CreateTimeRanges"` + MaterialTypes []any `json:"MaterialTypes"` + ReviewStatuses []any `json:"ReviewStatuses"` + Tags []any `json:"Tags"` + SearchScopes []struct { + Owner struct { + Type string `json:"Type"` + ID string `json:"Id"` + } `json:"Owner"` + ClassID int `json:"ClassId"` + SearchOneDepth bool `json:"SearchOneDepth"` + } `json:"SearchScopes"` +} + +type File struct { + Type string `json:"Type"` + ClassInfo struct { + ClassID int `json:"ClassId"` + Name string `json:"Name"` + UpdateTime time.Time `json:"UpdateTime"` + CreateTime time.Time `json:"CreateTime"` + FileInboxID string `json:"FileInboxId"` + Owner struct { + Type string `json:"Type"` + ID string `json:"Id"` + } `json:"Owner"` + ClassPath string `json:"ClassPath"` + ParentClassID int `json:"ParentClassId"` + AttachmentInfo struct { + SubClassCount int `json:"SubClassCount"` + MaterialCount int `json:"MaterialCount"` + Size int64 `json:"Size"` + } `json:"AttachmentInfo"` + ClassPreviewURLSet []string `json:"ClassPreviewUrlSet"` + } `json:"ClassInfo"` + MaterialInfo struct { + BasicInfo struct { + MaterialID string `json:"MaterialId"` + MaterialType string `json:"MaterialType"` + Name string `json:"Name"` + CreateTime time.Time `json:"CreateTime"` + UpdateTime time.Time 
`json:"UpdateTime"` + ClassPath string `json:"ClassPath"` + ClassID int `json:"ClassId"` + TagInfoSet []any `json:"TagInfoSet"` + TagSet []any `json:"TagSet"` + PreviewURL string `json:"PreviewUrl"` + MediaURL string `json:"MediaUrl"` + UnifiedMediaPreviewURL string `json:"UnifiedMediaPreviewUrl"` + Owner struct { + Type string `json:"Type"` + ID string `json:"Id"` + } `json:"Owner"` + PermissionSet any `json:"PermissionSet"` + PermissionInfoSet []any `json:"PermissionInfoSet"` + TfUID string `json:"TfUid"` + GroupID string `json:"GroupId"` + VersionMaterialIDSet []any `json:"VersionMaterialIdSet"` + FileType string `json:"FileType"` + CmeMaterialPlayList []any `json:"CmeMaterialPlayList"` + Status string `json:"Status"` + DownloadSwitch string `json:"DownloadSwitch"` + } `json:"BasicInfo"` + MediaInfo struct { + Width int `json:"Width"` + Height int `json:"Height"` + Size int `json:"Size"` + Duration float64 `json:"Duration"` + Fps int `json:"Fps"` + BitRate int `json:"BitRate"` + Codec string `json:"Codec"` + MediaType string `json:"MediaType"` + FavoriteStatus string `json:"FavoriteStatus"` + } `json:"MediaInfo"` + MaterialStatus struct { + ContentReviewStatus string `json:"ContentReviewStatus"` + EditorUsableStatus string `json:"EditorUsableStatus"` + UnifiedPreviewStatus string `json:"UnifiedPreviewStatus"` + EditPreviewImageSpiritStatus string `json:"EditPreviewImageSpiritStatus"` + TranscodeStatus string `json:"TranscodeStatus"` + AdaptiveStreamingStatus string `json:"AdaptiveStreamingStatus"` + StreamConnectable string `json:"StreamConnectable"` + AiAnalysisStatus string `json:"AiAnalysisStatus"` + AiRecognitionStatus string `json:"AiRecognitionStatus"` + } `json:"MaterialStatus"` + ImageMaterial struct { + Height int `json:"Height"` + Width int `json:"Width"` + Size int `json:"Size"` + MaterialURL string `json:"MaterialUrl"` + Resolution string `json:"Resolution"` + VodFileID string `json:"VodFileId"` + OriginalURL string `json:"OriginalUrl"` + } `json:"ImageMaterial"` + VideoMaterial struct { + MetaData struct { + Size int `json:"Size"` + Container string `json:"Container"` + Bitrate int `json:"Bitrate"` + Height int `json:"Height"` + Width int `json:"Width"` + Duration float64 `json:"Duration"` + Rotate int `json:"Rotate"` + VideoStreamInfoSet []struct { + Bitrate int `json:"Bitrate"` + Height int `json:"Height"` + Width int `json:"Width"` + Codec string `json:"Codec"` + Fps int `json:"Fps"` + } `json:"VideoStreamInfoSet"` + AudioStreamInfoSet []struct { + Bitrate int `json:"Bitrate"` + SamplingRate int `json:"SamplingRate"` + Codec string `json:"Codec"` + } `json:"AudioStreamInfoSet"` + } `json:"MetaData"` + ImageSpriteInfo any `json:"ImageSpriteInfo"` + MaterialURL string `json:"MaterialUrl"` + CoverURL string `json:"CoverUrl"` + Resolution string `json:"Resolution"` + VodFileID string `json:"VodFileId"` + OriginalURL string `json:"OriginalUrl"` + AudioWaveformURL string `json:"AudioWaveformUrl"` + SubtitleURL string `json:"SubtitleUrl"` + TranscodeInfoSet []any `json:"TranscodeInfoSet"` + ImageSpriteInfoSet []any `json:"ImageSpriteInfoSet"` + } `json:"VideoMaterial"` + } `json:"MaterialInfo"` +} + +type RspFiles struct { + Code string `json:"Code"` + Message string `json:"Message"` + EnglishMessage string `json:"EnglishMessage"` + Data struct { + TotalCount int `json:"TotalCount"` + ResourceInfoSet []File `json:"ResourceInfoSet"` + ScrollToken string `json:"ScrollToken"` + } `json:"Data"` +} + +type RspDown struct { + Code string `json:"Code"` + Message string 
`json:"Message"` + EnglishMessage string `json:"EnglishMessage"` + Data struct { + DownloadURLInfoSet []struct { + MaterialID string `json:"MaterialId"` + DownloadURL string `json:"DownloadUrl"` + } `json:"DownloadUrlInfoSet"` + } `json:"Data"` +} + +type RspCreatrMaterial struct { + Code string `json:"Code"` + Message string `json:"Message"` + EnglishMessage string `json:"EnglishMessage"` + Data struct { + UploadContext string `json:"UploadContext"` + VodUploadSign string `json:"VodUploadSign"` + QuickUpload bool `json:"QuickUpload"` + } `json:"Data"` +} + +type RspApplyUploadUGC struct { + Code int `json:"code"` + Message string `json:"message"` + Data struct { + Video struct { + StorageSignature string `json:"storageSignature"` + StoragePath string `json:"storagePath"` + } `json:"video"` + StorageAppID int `json:"storageAppId"` + StorageBucket string `json:"storageBucket"` + StorageRegion string `json:"storageRegion"` + StorageRegionV5 string `json:"storageRegionV5"` + Domain string `json:"domain"` + VodSessionKey string `json:"vodSessionKey"` + TempCertificate struct { + SecretID string `json:"secretId"` + SecretKey string `json:"secretKey"` + Token string `json:"token"` + ExpiredTime int `json:"expiredTime"` + } `json:"tempCertificate"` + AppID int `json:"appId"` + Timestamp int `json:"timestamp"` + StorageRegionV50 string `json:"StorageRegionV5"` + MiniProgramAccelerateHost string `json:"MiniProgramAccelerateHost"` + } `json:"data"` +} + +type RspCommitUploadUGC struct { + Code int `json:"code"` + Message string `json:"message"` + Data struct { + Video struct { + URL string `json:"url"` + VerifyContent string `json:"verify_content"` + } `json:"video"` + FileID string `json:"fileId"` + } `json:"data"` +} + +type RspFinishUpload struct { + Code string `json:"Code"` + Message string `json:"Message"` + EnglishMessage string `json:"EnglishMessage"` + Data struct { + MaterialID string `json:"MaterialId"` + } `json:"Data"` +} + +func fileToObj(f File) *model.Object { + obj := &model.Object{} + if f.Type == "CLASS" { + obj.Name = f.ClassInfo.Name + obj.ID = strconv.Itoa(f.ClassInfo.ClassID) + obj.IsFolder = true + obj.Modified = f.ClassInfo.CreateTime + obj.Size = 0 + } else if f.Type == "MATERIAL" { + obj.Name = f.MaterialInfo.BasicInfo.Name + obj.ID = f.MaterialInfo.BasicInfo.MaterialID + obj.IsFolder = false + obj.Modified = f.MaterialInfo.BasicInfo.CreateTime + obj.Size = int64(f.MaterialInfo.MediaInfo.Size) + } + return obj +} diff --git a/drivers/vtencent/util.go b/drivers/vtencent/util.go new file mode 100644 index 0000000000000000000000000000000000000000..ba87f1abe519962dba733b795f36739342cbd80f --- /dev/null +++ b/drivers/vtencent/util.go @@ -0,0 +1,300 @@ +package vtencent + +import ( + "context" + "crypto/sha1" + "encoding/hex" + "errors" + "fmt" + "io" + "net/http" + "path" + "strconv" + "strings" + + "github.com/alist-org/alist/v3/drivers/base" + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/pkg/http_range" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/s3/s3manager" + "github.com/go-resty/resty/v2" +) + +func (d *Vtencent) request(url, method string, callback base.ReqCallback, resp interface{}) ([]byte, error) { + req := base.RestyClient.R() + req.SetHeaders(map[string]string{ + "cookie": d.Cookie, + "content-type": "application/json", + 
"origin": d.conf.origin, + "referer": d.conf.referer, + }) + if callback != nil { + callback(req) + } else { + req.SetBody("{}") + } + if resp != nil { + req.SetResult(resp) + } + res, err := req.Execute(method, url) + if err != nil { + return nil, err + } + code := utils.Json.Get(res.Body(), "Code").ToString() + if code != "Success" { + switch code { + case "AuthFailure.SessionInvalid": + if err != nil { + return nil, errors.New(code) + } + default: + return nil, errors.New(code) + } + return d.request(url, method, callback, resp) + } + return res.Body(), nil +} + +func (d *Vtencent) ugcRequest(url, method string, callback base.ReqCallback, resp interface{}) ([]byte, error) { + req := base.RestyClient.R() + req.SetHeaders(map[string]string{ + "cookie": d.Cookie, + "content-type": "application/json", + "origin": d.conf.origin, + "referer": d.conf.referer, + }) + if callback != nil { + callback(req) + } else { + req.SetBody("{}") + } + if resp != nil { + req.SetResult(resp) + } + res, err := req.Execute(method, url) + if err != nil { + return nil, err + } + code := utils.Json.Get(res.Body(), "Code").ToInt() + if code != 0 { + message := utils.Json.Get(res.Body(), "message").ToString() + if len(message) == 0 { + message = utils.Json.Get(res.Body(), "msg").ToString() + } + return nil, errors.New(message) + } + return res.Body(), nil +} + +func (d *Vtencent) LoadUser() (string, error) { + api := "https://api.vs.tencent.com/SaaS/Account/DescribeAccount" + res, err := d.request(api, http.MethodPost, func(req *resty.Request) {}, nil) + if err != nil { + return "", err + } + return utils.Json.Get(res, "Data", "TfUid").ToString(), nil +} + +func (d *Vtencent) GetFiles(dirId string) ([]File, error) { + var res []File + //offset := 0 + for { + api := "https://api.vs.tencent.com/PaaS/Material/SearchResource" + form := fmt.Sprintf(`{ + "Text":"", + "Text":"", + "Offset":%d, + "Limit":50, + "Sort":{"Field":"%s","Order":"%s"}, + "CreateTimeRanges":[], + "MaterialTypes":[], + "ReviewStatuses":[], + "Tags":[], + "SearchScopes":[{"Owner":{"Type":"PERSON","Id":"%s"},"ClassId":%s,"SearchOneDepth":true}] + }`, len(res), d.Addition.OrderBy, d.Addition.OrderDirection, d.TfUid, dirId) + var resp RspFiles + _, err := d.request(api, http.MethodPost, func(req *resty.Request) { + req.SetBody(form).ForceContentType("application/json") + }, &resp) + if err != nil { + return nil, err + } + res = append(res, resp.Data.ResourceInfoSet...) 
+ if len(resp.Data.ResourceInfoSet) <= 0 || len(res) >= resp.Data.TotalCount { + break + } + } + return res, nil +} + +func (d *Vtencent) CreateUploadMaterial(classId int, fileName string, UploadSummaryKey string) (RspCreatrMaterial, error) { + api := "https://api.vs.tencent.com/PaaS/Material/CreateUploadMaterial" + form := base.Json{"Owner": base.Json{"Type": "PERSON", "Id": d.TfUid}, + "MaterialType": "VIDEO", "Name": fileName, "ClassId": classId, + "UploadSummaryKey": UploadSummaryKey} + var resps RspCreatrMaterial + _, err := d.request(api, http.MethodPost, func(req *resty.Request) { + req.SetBody(form).ForceContentType("application/json") + }, &resps) + if err != nil { + return RspCreatrMaterial{}, err + } + return resps, nil +} + +func (d *Vtencent) ApplyUploadUGC(signature string, stream model.FileStreamer) (RspApplyUploadUGC, error) { + api := "https://vod2.qcloud.com/v3/index.php?Action=ApplyUploadUGC" + form := base.Json{ + "signature": signature, + "videoName": stream.GetName(), + "videoType": strings.ReplaceAll(path.Ext(stream.GetName()), ".", ""), + "videoSize": stream.GetSize(), + } + var resps RspApplyUploadUGC + _, err := d.ugcRequest(api, http.MethodPost, func(req *resty.Request) { + req.SetBody(form).ForceContentType("application/json") + }, &resps) + if err != nil { + return RspApplyUploadUGC{}, err + } + return resps, nil +} + +func (d *Vtencent) CommitUploadUGC(signature string, vodSessionKey string) (RspCommitUploadUGC, error) { + api := "https://vod2.qcloud.com/v3/index.php?Action=CommitUploadUGC" + form := base.Json{ + "signature": signature, + "vodSessionKey": vodSessionKey, + } + var resps RspCommitUploadUGC + rsp, err := d.ugcRequest(api, http.MethodPost, func(req *resty.Request) { + req.SetBody(form).ForceContentType("application/json") + }, &resps) + if err != nil { + return RspCommitUploadUGC{}, err + } + if len(resps.Data.Video.URL) == 0 { + return RspCommitUploadUGC{}, errors.New(string(rsp)) + } + return resps, nil +} + +func (d *Vtencent) FinishUploadMaterial(SummaryKey string, VodVerifyKey string, UploadContext, VodFileId string) (RspFinishUpload, error) { + api := "https://api.vs.tencent.com/PaaS/Material/FinishUploadMaterial" + form := base.Json{ + "UploadContext": UploadContext, + "VodVerifyKey": VodVerifyKey, + "VodFileId": VodFileId, + "UploadFullKey": SummaryKey} + var resps RspFinishUpload + rsp, err := d.request(api, http.MethodPost, func(req *resty.Request) { + req.SetBody(form).ForceContentType("application/json") + }, &resps) + if err != nil { + return RspFinishUpload{}, err + } + if len(resps.Data.MaterialID) == 0 { + return RspFinishUpload{}, errors.New(string(rsp)) + } + return resps, nil +} + +func (d *Vtencent) FinishHashUploadMaterial(SummaryKey string, UploadContext string) (RspFinishUpload, error) { + api := "https://api.vs.tencent.com/PaaS/Material/FinishUploadMaterial" + var resps RspFinishUpload + form := base.Json{ + "UploadContext": UploadContext, + "UploadFullKey": SummaryKey} + rsp, err := d.request(api, http.MethodPost, func(req *resty.Request) { + req.SetBody(form).ForceContentType("application/json") + }, &resps) + if err != nil { + return RspFinishUpload{}, err + } + if len(resps.Data.MaterialID) == 0 { + return RspFinishUpload{}, errors.New(string(rsp)) + } + return resps, nil +} + +func (d *Vtencent) FileUpload(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error { + classId, err := strconv.Atoi(dstDir.GetID()) + if err != nil { + return err + } + const chunkLength int64 = 1024 * 
1024 * 1024 * 10
+	reader, err := stream.RangeRead(http_range.Range{Start: 0, Length: chunkLength})
+	if err != nil {
+		return err
+	}
+	chunkHash, err := utils.HashReader(utils.SHA1, reader)
+	if err != nil {
+		return err
+	}
+	rspCreatrMaterial, err := d.CreateUploadMaterial(classId, stream.GetName(), chunkHash)
+	if err != nil {
+		return err
+	}
+	if rspCreatrMaterial.Data.QuickUpload {
+		SummaryKey := stream.GetHash().GetHash(utils.SHA1)
+		if len(SummaryKey) < utils.SHA1.Width {
+			if SummaryKey, err = utils.HashReader(utils.SHA1, stream); err != nil {
+				return err
+			}
+		}
+		UploadContext := rspCreatrMaterial.Data.UploadContext
+		_, err = d.FinishHashUploadMaterial(SummaryKey, UploadContext)
+		if err != nil {
+			return err
+		}
+		return nil
+	}
+	hash := sha1.New()
+	rspUGC, err := d.ApplyUploadUGC(rspCreatrMaterial.Data.VodUploadSign, stream)
+	if err != nil {
+		return err
+	}
+	params := rspUGC.Data
+	certificate := params.TempCertificate
+	cfg := &aws.Config{
+		HTTPClient: base.HttpClient,
+		// S3ForcePathStyle: aws.Bool(true),
+		Credentials: credentials.NewStaticCredentials(certificate.SecretID, certificate.SecretKey, certificate.Token),
+		Region:      aws.String(params.StorageRegionV5),
+		Endpoint:    aws.String(fmt.Sprintf("cos.%s.myqcloud.com", params.StorageRegionV5)),
+	}
+	ss, err := session.NewSession(cfg)
+	if err != nil {
+		return err
+	}
+	uploader := s3manager.NewUploader(ss)
+	if stream.GetSize() > s3manager.MaxUploadParts*s3manager.DefaultUploadPartSize {
+		uploader.PartSize = stream.GetSize() / (s3manager.MaxUploadParts - 1)
+	}
+	input := &s3manager.UploadInput{
+		Bucket: aws.String(fmt.Sprintf("%s-%d", params.StorageBucket, params.StorageAppID)),
+		Key:    &params.Video.StoragePath,
+		Body:   io.TeeReader(stream, io.MultiWriter(hash, driver.NewProgress(stream.GetSize(), up))),
+	}
+	_, err = uploader.UploadWithContext(ctx, input)
+	if err != nil {
+		return err
+	}
+	rspCommitUGC, err := d.CommitUploadUGC(rspCreatrMaterial.Data.VodUploadSign, rspUGC.Data.VodSessionKey)
+	if err != nil {
+		return err
+	}
+	VodVerifyKey := rspCommitUGC.Data.Video.VerifyContent
+	VodFileId := rspCommitUGC.Data.FileID
+	UploadContext := rspCreatrMaterial.Data.UploadContext
+	SummaryKey := hex.EncodeToString(hash.Sum(nil))
+	_, err = d.FinishUploadMaterial(SummaryKey, VodVerifyKey, UploadContext, VodFileId)
+	if err != nil {
+		return err
+	}
+	return nil
+}
diff --git a/drivers/webdav/driver.go b/drivers/webdav/driver.go
new file mode 100644
index 0000000000000000000000000000000000000000..5a98057f31f09a106bfe33955d2301ce11b635ea
--- /dev/null
+++ b/drivers/webdav/driver.go
@@ -0,0 +1,132 @@
+package webdav
+
+import (
+	"context"
+	"io"
+	"net/http"
+	"os"
+	"path"
+	"time"
+
+	"github.com/alist-org/alist/v3/internal/driver"
+	"github.com/alist-org/alist/v3/internal/model"
+	"github.com/alist-org/alist/v3/pkg/cron"
+	"github.com/alist-org/alist/v3/pkg/gowebdav"
+	"github.com/alist-org/alist/v3/pkg/utils"
+)
+
+type WebDav struct {
+	model.Storage
+	Addition
+	client *gowebdav.Client
+	cron   *cron.Cron
+}
+
+func (d *WebDav) Config() driver.Config {
+	return config
+}
+
+func (d *WebDav) GetAddition() driver.Additional {
+	return &d.Addition
+}
+
+func (d *WebDav) Init(ctx context.Context) error {
+	err := d.setClient()
+	if err == nil {
+		d.cron = cron.NewCron(time.Hour * 12)
+		d.cron.Do(func() {
+			_ = d.setClient()
+		})
+	}
+	return err
+}
+
+func (d *WebDav) Drop(ctx context.Context) error {
+	if d.cron != nil {
+		d.cron.Stop()
+	}
+	return nil
+}
+
+func (d *WebDav) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) {
+	files, err := d.client.ReadDir(dir.GetPath())
+	if err != nil {
+		return nil, err
+	}
+	return utils.SliceConvert(files, func(src os.FileInfo) (model.Obj, error) {
+		return &model.Object{
+			Name:     src.Name(),
+			Size:     src.Size(),
+			Modified: src.ModTime(),
+			IsFolder: src.IsDir(),
+		}, nil
+	})
+}
+
+func (d *WebDav) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) {
+	url, header, err := d.client.Link(file.GetPath())
+	if err != nil {
+		return nil, err
+	}
+	return &model.Link{
+		URL:    url,
+		Header: header,
+	}, nil
+}
+
+func (d *WebDav) MakeDir(ctx context.Context, parentDir model.Obj, dirName string) error {
+	return d.client.MkdirAll(path.Join(parentDir.GetPath(), dirName), 0644)
+}
+
+func (d *WebDav) Move(ctx context.Context, srcObj, dstDir model.Obj) error {
+	return d.client.Rename(getPath(srcObj), path.Join(dstDir.GetPath(), srcObj.GetName()), true)
+}
+
+func (d *WebDav) Rename(ctx context.Context, srcObj model.Obj, newName string) error {
+	return d.client.Rename(getPath(srcObj), path.Join(path.Dir(srcObj.GetPath()), newName), true)
+}
+
+func (d *WebDav) Copy(ctx context.Context, srcObj, dstDir model.Obj) error {
+	return d.client.Copy(getPath(srcObj), path.Join(dstDir.GetPath(), srcObj.GetName()), true)
+}
+
+func (d *WebDav) Remove(ctx context.Context, obj model.Obj) error {
+	return d.client.RemoveAll(getPath(obj))
+}
+
+func (d *WebDav) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error {
+	callback := func(r *http.Request) {
+		r.Header.Set("Content-Type", stream.GetMimetype())
+		r.ContentLength = stream.GetSize()
+	}
+
+	// wrap the stream so read progress is reported while uploading
+	progressReader := &ProgressReader{
+		Reader:   stream,
+		Total:    stream.GetSize(),
+		Progress: up,
+	}
+
+	// TODO: support cancel
+	err := d.client.WriteStream(path.Join(dstDir.GetPath(), stream.GetName()), progressReader, 0644, callback)
+	return err
+}
+
+// ProgressReader tracks how many bytes have been read and reports progress
+type ProgressReader struct {
+	Reader   io.Reader
+	Total    int64
+	Progress driver.UpdateProgress
+	read     int64
+}
+
+func (pr *ProgressReader) Read(p []byte) (int, error) {
+	n, err := pr.Reader.Read(p)
+	if n > 0 {
+		pr.read += int64(n)
+		pr.Progress(float64(pr.read) / float64(pr.Total) * 100)
+	}
+	return n, err
+}
+
+var _ driver.Driver = (*WebDav)(nil)
diff --git a/drivers/webdav/meta.go b/drivers/webdav/meta.go
new file mode 100644
index 0000000000000000000000000000000000000000..2294d482a6e9aff0290ce68af3148ce9558bfd50
--- /dev/null
+++ b/drivers/webdav/meta.go
@@ -0,0 +1,28 @@
+package webdav
+
+import (
+	"github.com/alist-org/alist/v3/internal/driver"
+	"github.com/alist-org/alist/v3/internal/op"
+)
+
+type Addition struct {
+	Vendor   string `json:"vendor" type:"select" options:"sharepoint,other" default:"other"`
+	Address  string `json:"address" required:"true"`
+	Username string `json:"username" required:"true"`
+	Password string `json:"password" required:"true"`
+	driver.RootPath
+	TlsInsecureSkipVerify bool `json:"tls_insecure_skip_verify" default:"false"`
+}
+
+var config = driver.Config{
+	Name:        "WebDav",
+	LocalSort:   true,
+	OnlyProxy:   true,
+	DefaultRoot: "/",
+}
+
+func init() {
+	op.RegisterDriver(func() driver.Driver {
+		return &WebDav{}
+	})
+}
+package odrvcookie
+
+import (
+    "net/http"
+
+    "github.com/alist-org/alist/v3/pkg/cookie"
+)
+
+//type SpCookie struct {
+//    Cookie string
+//    expire time.Time
+//}
+//
+//func (sp SpCookie) IsExpire() bool {
+//    return time.Now().After(sp.expire)
+//}
+//
+//var cookiesMap = struct {
+//    sync.Mutex
+//    m map[string]*SpCookie
+//}{m: make(map[string]*SpCookie)}
+
+func GetCookie(username, password, siteUrl string) (string, error) {
+    //cookiesMap.Lock()
+    //defer cookiesMap.Unlock()
+    //spCookie, ok := cookiesMap.m[username]
+    //if ok {
+    //    if !spCookie.IsExpire() {
+    //        log.Debugln("sp use old cookie.")
+    //        return spCookie.Cookie, nil
+    //    }
+    //}
+    //log.Debugln("fetch new cookie")
+    ca := New(username, password, siteUrl)
+    tokenConf, err := ca.Cookies()
+    if err != nil {
+        return "", err
+    }
+    return cookie.ToString([]*http.Cookie{&tokenConf.RtFa, &tokenConf.FedAuth}), nil
+    //spCookie = &SpCookie{
+    //    Cookie: cookie.ToString([]*http.Cookie{&tokenConf.RtFa, &tokenConf.FedAuth}),
+    //    expire: time.Now().Add(time.Hour * 12),
+    //}
+    //cookiesMap.m[username] = spCookie
+    //return spCookie.Cookie, nil
+}
diff --git a/drivers/webdav/odrvcookie/fetch.go b/drivers/webdav/odrvcookie/fetch.go
new file mode 100644
index 0000000000000000000000000000000000000000..a52fc68be0eb86694fd67db75679ec01ae83ad0b
--- /dev/null
+++ b/drivers/webdav/odrvcookie/fetch.go
@@ -0,0 +1,207 @@
+// Package odrvcookie can fetch authentication cookies for a sharepoint webdav endpoint
+package odrvcookie
+
+import (
+    "bytes"
+    "encoding/xml"
+    "fmt"
+    "html/template"
+    "net/http"
+    "net/http/cookiejar"
+    "net/url"
+    "strings"
+    "time"
+
+    "github.com/alist-org/alist/v3/drivers/base"
+    "golang.org/x/net/publicsuffix"
+)
+
+// CookieAuth holds the authentication information:
+// the username and password as well as the authentication endpoint
+type CookieAuth struct {
+    user     string
+    pass     string
+    endpoint string
+}
+
+// CookieResponse contains the requested cookies
+type CookieResponse struct {
+    RtFa    http.Cookie
+    FedAuth http.Cookie
+}
+
+// SuccessResponse holds a response from the sharepoint webdav
+type SuccessResponse struct {
+    XMLName xml.Name            `xml:"Envelope"`
+    Succ    SuccessResponseBody `xml:"Body"`
+}
+
+// SuccessResponseBody is the body of a success response, it holds the token
+type SuccessResponseBody struct {
+    XMLName xml.Name
+    Type    string    `xml:"RequestSecurityTokenResponse>TokenType"`
+    Created time.Time `xml:"RequestSecurityTokenResponse>Lifetime>Created"`
+    Expires time.Time `xml:"RequestSecurityTokenResponse>Lifetime>Expires"`
+    Token   string    `xml:"RequestSecurityTokenResponse>RequestedSecurityToken>BinarySecurityToken"`
+}
+
+// reqString is a template that gets populated with the user data in order to retrieve a "BinarySecurityToken"
+const reqString = `<s:Envelope xmlns:s="http://www.w3.org/2003/05/soap-envelope"
+xmlns:a="http://www.w3.org/2005/08/addressing"
+xmlns:u="http://docs.oasis-open.org/wss/2004/01/oasis-200401-wss-wssecurity-utility-1.0.xsd">
+<s:Header>
+<a:Action s:mustUnderstand="1">http://schemas.xmlsoap.org/ws/2005/02/trust/RST/Issue</a:Action>
+<a:ReplyTo>
+<a:Address>http://www.w3.org/2005/08/addressing/anonymous</a:Address>
+</a:ReplyTo>
+<a:To s:mustUnderstand="1">{{ .LoginUrl }}</a:To>
+<o:Security s:mustUnderstand="1"
+xmlns:o="http://docs.oasis-open.org/wss/2004/01/oasis-200401-wss-wssecurity-secext-1.0.xsd">
+<o:UsernameToken>
+<o:Username>{{ .Username }}</o:Username>
+<o:Password>{{ .Password }}</o:Password>
+</o:UsernameToken>
+</o:Security>
+</s:Header>
+<s:Body>
+<t:RequestSecurityToken xmlns:t="http://schemas.xmlsoap.org/ws/2005/02/trust">
+<wsp:AppliesTo xmlns:wsp="http://schemas.xmlsoap.org/ws/2004/09/policy">
+<a:EndpointReference>
+<a:Address>{{ .Address }}</a:Address>
+</a:EndpointReference>
+</wsp:AppliesTo>
+<t:KeyType>http://schemas.xmlsoap.org/ws/2005/05/identity/NoProofKey</t:KeyType>
+<t:RequestType>http://schemas.xmlsoap.org/ws/2005/02/trust/Issue</t:RequestType>
+<t:TokenType>urn:oasis:names:tc:SAML:1.0:assertion</t:TokenType>
+</t:RequestSecurityToken>
+</s:Body>
+</s:Envelope>`
+
+// New creates a new CookieAuth struct
+func New(pUser, pPass, pEndpoint string) CookieAuth {
+    retStruct := CookieAuth{
+        user:     pUser,
+        pass:     pPass,
+        endpoint: pEndpoint,
+    }
+
+    return retStruct
+}
+
+// Cookies creates a CookieResponse. It fetches the auth token and then
+// retrieves the Cookies
+func (ca *CookieAuth) Cookies() (CookieResponse, error) {
+    spToken, err := ca.getSPToken()
+    if err != nil {
+        return CookieResponse{}, err
+    }
+    return ca.getSPCookie(spToken)
+}
+
+func (ca *CookieAuth) getSPCookie(conf *SuccessResponse) (CookieResponse, error) {
+    spRoot, err := url.Parse(ca.endpoint)
+    if err != nil {
+        return CookieResponse{}, err
+    }
+
+    u, err := url.Parse("https://" + spRoot.Host + "/_forms/default.aspx?wa=wsignin1.0")
+    if err != nil {
+        return CookieResponse{}, err
+    }
+
+    // To authenticate with davfs or anything else we need two cookies (rtFa and FedAuth)
+    // In order to get them we use the token we got earlier and a cookieJar
+    jar, err := cookiejar.New(&cookiejar.Options{PublicSuffixList: publicsuffix.List})
+    if err != nil {
+        return CookieResponse{}, err
+    }
+
+    client := &http.Client{
+        Jar: jar,
+    }
+
+    // Send the previously acquired Token as a Post parameter
+    if _, err = client.Post(u.String(), "text/xml", strings.NewReader(conf.Succ.Token)); err != nil {
+        return CookieResponse{}, err
+    }
+
+    cookieResponse := CookieResponse{}
+    for _, cookie := range jar.Cookies(u) {
+        if (cookie.Name == "rtFa") || (cookie.Name == "FedAuth") {
+            switch cookie.Name {
+            case "rtFa":
+                cookieResponse.RtFa = *cookie
+            case "FedAuth":
+                cookieResponse.FedAuth = *cookie
+            }
+        }
+    }
+    return cookieResponse, err
+}
+
+var loginUrlsMap = map[string]string{
+    "com": "https://login.microsoftonline.com",
+    "cn":  "https://login.chinacloudapi.cn",
+    "us":  "https://login.microsoftonline.us",
+    "de":  "https://login.microsoftonline.de",
+}
+
+func getLoginUrl(endpoint string) (string, error) {
+    spRoot, err := url.Parse(endpoint)
+    if err != nil {
+        return "", err
+    }
+    domains := strings.Split(spRoot.Host, ".")
+    tld := domains[len(domains)-1]
+    loginUrl, ok := loginUrlsMap[tld]
+    if !ok {
+        return "", fmt.Errorf("tld %s is not supported", tld)
+    }
+    return loginUrl + "/extSTS.srf", nil
+}
+
+func (ca *CookieAuth) getSPToken() (*SuccessResponse, error) {
+    loginUrl, err := getLoginUrl(ca.endpoint)
+    if err != nil {
+        return nil, err
+    }
+    reqData := map[string]string{
+        "Username": ca.user,
+        "Password": ca.pass,
+        "Address":  ca.endpoint,
+        "LoginUrl": loginUrl,
+    }
+
+    t := template.Must(template.New("authXML").Parse(reqString))
+
+    buf := &bytes.Buffer{}
+    if err := t.Execute(buf, reqData); err != nil {
+        return nil, err
+    }
+
+    // Execute the first request which gives us an auth token for the sharepoint service
+    // With this token we can authenticate on the login page and save the returned cookies
+    req, err := http.NewRequest("POST", loginUrl, buf)
+    if err != nil {
+        return nil, err
+    }
+
+    client := base.HttpClient
+    resp, err := client.Do(req)
+    if err != nil {
+        return nil, err
+    }
+    defer resp.Body.Close()
+
+    respBuf := bytes.Buffer{}
+    if _, err := respBuf.ReadFrom(resp.Body); err != nil {
+        return nil, err
+    }
+    s := respBuf.Bytes()
+
+    var conf SuccessResponse
+    err = xml.Unmarshal(s, &conf)
+    if err != nil {
+        return nil, err
+    }
+
+    return &conf, err
+}
diff --git a/drivers/webdav/types.go b/drivers/webdav/types.go
new file mode 100644
index 0000000000000000000000000000000000000000..0541cc2d8125edfed7388648ecb8b2584779d570
--- /dev/null
+++ b/drivers/webdav/types.go
@@ -0,0 +1 @@
+package webdav
diff --git a/drivers/webdav/util.go b/drivers/webdav/util.go
new file mode 100644
index 0000000000000000000000000000000000000000..23dc909ff8863ac4566a9d6737b7a6f1537ddb1f
--- /dev/null
+++ b/drivers/webdav/util.go
@@ -0,0 +1,52 @@
+package webdav
+
+import (
+    "crypto/tls"
+    "net/http"
+    "net/http/cookiejar"
+
+    "github.com/alist-org/alist/v3/drivers/webdav/odrvcookie"
+    "github.com/alist-org/alist/v3/internal/model"
+    "github.com/alist-org/alist/v3/pkg/gowebdav"
+)
+
+// do others that not defined in Driver interface
+
+func (d *WebDav) isSharepoint() bool {
+    return d.Vendor == "sharepoint"
+}
+
+func (d *WebDav) setClient() error {
+    c := gowebdav.NewClient(d.Address, d.Username, d.Password)
+    c.SetTransport(&http.Transport{
+        Proxy:           http.ProxyFromEnvironment,
+        TLSClientConfig: &tls.Config{InsecureSkipVerify: d.TlsInsecureSkipVerify},
+    })
+    if d.isSharepoint() {
+        cookie, err := odrvcookie.GetCookie(d.Username, d.Password, d.Address)
+        if err == nil {
+            c.SetInterceptor(func(method string, rq *http.Request) {
+                rq.Header.Del("Authorization")
+                rq.Header.Set("Cookie", cookie)
+            })
+        } else {
+            return err
+        }
+    } else {
+        cookieJar, err := cookiejar.New(nil)
+        if err == nil {
+            c.SetJar(cookieJar)
+        } else {
+            return err
+        }
+    }
+    d.client = c
+    return nil
+}
+
+func getPath(obj model.Obj) string {
+    if obj.IsDir() {
+        return obj.GetPath() + "/"
+    }
+    return obj.GetPath()
+}
diff --git a/drivers/weiyun/driver.go b/drivers/weiyun/driver.go
new file mode 100644
index 0000000000000000000000000000000000000000..e6d5897c313dbe91ce843c93b68ee7d7b4cb82df
--- /dev/null
+++ b/drivers/weiyun/driver.go
@@ -0,0 +1,400 @@
+package weiyun
+
+import (
+    "context"
+    "fmt"
+    "io"
+    "math"
+    "net/http"
+    "strconv"
+    "time"
+
+    "github.com/alist-org/alist/v3/drivers/base"
+    "github.com/alist-org/alist/v3/internal/driver"
+    "github.com/alist-org/alist/v3/internal/errs"
+    "github.com/alist-org/alist/v3/internal/model"
+    "github.com/alist-org/alist/v3/internal/op"
+    "github.com/alist-org/alist/v3/pkg/cron"
+    "github.com/alist-org/alist/v3/pkg/errgroup"
+    "github.com/alist-org/alist/v3/pkg/utils"
+    "github.com/avast/retry-go"
+    weiyunsdkgo "github.com/foxxorcat/weiyun-sdk-go"
+)
+
+type WeiYun struct {
+    model.Storage
+    Addition
+
+    client     *weiyunsdkgo.WeiYunClient
+    cron       *cron.Cron
+    rootFolder *Folder
+
+    uploadThread int
+}
+
+func (d *WeiYun) Config() driver.Config {
+    return config
+}
+
+func (d *WeiYun) GetAddition() driver.Additional {
+    return &d.Addition
+}
+
+func (d *WeiYun) Init(ctx context.Context) error {
+    // limit the number of upload threads
+    d.uploadThread, _ = strconv.Atoi(d.UploadThread)
+    if d.uploadThread < 4 || d.uploadThread > 32 {
+        d.uploadThread, d.UploadThread = 4, "4"
+    }
+
+    d.client = weiyunsdkgo.NewWeiYunClientWithRestyClient(base.NewRestyClient())
+    err := d.client.SetCookiesStr(d.Cookies).RefreshCtoken()
+    if err != nil {
+        return err
+    }
+
+    // callback for cookie expiry
+    d.client.SetOnCookieExpired(func(err error) {
+        d.Status = err.Error()
+        op.MustSaveDriverStorage(d)
+    })
+
+    // callback for cookie refresh
+    d.client.SetOnCookieUpload(func(c []*http.Cookie) {
+        d.Cookies = weiyunsdkgo.CookieToString(weiyunsdkgo.ClearCookie(c))
+        op.MustSaveDriverStorage(d)
+    })
+
+    // keep the QQ cookie alive
+    if d.client.LoginType() == 1 {
+        d.cron = cron.NewCron(time.Minute * 5)
+        d.cron.Do(func() {
+            d.client.KeepAlive()
+        })
+    }
+
+    // fetch the default root dirKey
+    if d.RootFolderID == "" {
+        userInfo, err := d.client.DiskUserInfoGet()
+        if err != nil {
+            return err
+        }
+        d.RootFolderID = userInfo.MainDirKey
+    }
+
+    // resolve the directory ID to find its PdirKey
+    folders, err := d.client.LibDirPathGet(d.RootFolderID)
+    if err != nil {
+        return err
+    }
+    if len(folders) == 0 {
+        return fmt.Errorf("invalid directory ID")
+    }
+
+    folder := folders[len(folders)-1]
+    d.rootFolder = &Folder{
+        PFolder: &Folder{
+            Folder: weiyunsdkgo.Folder{
+                DirKey: folder.PdirKey,
+            },
+        },
+        Folder: folder.Folder,
+    }
+    return nil
+}
+
+func (d *WeiYun) Drop(ctx context.Context) error {
+    d.client = nil
+    if d.cron != nil {
+        d.cron.Stop()
+        d.cron = nil
+    }
+    return nil
+}
+
+func (d *WeiYun) GetRoot(ctx context.Context) (model.Obj, error) {
+    return d.rootFolder, nil
+}
+
+func (d *WeiYun) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) {
+    if folder, ok := dir.(*Folder); ok {
+        var files []model.Obj
+        for {
+            data, err := d.client.DiskDirFileList(folder.GetID(), weiyunsdkgo.WarpParamOption(
+                weiyunsdkgo.QueryFileOptionOffest(int64(len(files))),
+                weiyunsdkgo.QueryFileOptionGetType(weiyunsdkgo.FileAndDir),
+                weiyunsdkgo.QueryFileOptionSort(func() weiyunsdkgo.OrderBy {
+                    switch d.OrderBy {
+                    case "name":
+                        return weiyunsdkgo.FileName
+                    case "size":
+                        return weiyunsdkgo.FileSize
+                    case "updated_at":
+                        return weiyunsdkgo.FileMtime
+                    default:
+                        return weiyunsdkgo.FileName
+                    }
+                }(), d.OrderDirection == "desc"),
+            ))
+            if err != nil {
+                return nil, err
+            }
+
+            if files == nil {
+                files = make([]model.Obj, 0, data.TotalDirCount+data.TotalFileCount)
+            }
+
+            for _, dir := range data.DirList {
+                files = append(files, &Folder{
+                    PFolder: folder,
+                    Folder:  dir,
+                })
+            }
+
+            for _, file := range data.FileList {
+                files = append(files, &File{
+                    PFolder: folder,
+                    File:    file,
+                })
+            }
+
+            if data.FinishFlag || len(data.DirList)+len(data.FileList) == 0 {
+                return files, nil
+            }
+        }
+    }
+    return nil, errs.NotSupport
+}
+
+func (d *WeiYun) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) {
+    if file, ok := file.(*File); ok {
+        data, err := d.client.DiskFileDownload(weiyunsdkgo.FileParam{PdirKey: file.GetPKey(), FileID: file.GetID()})
+        if err != nil {
+            return nil, err
+        }
+        return &model.Link{
+            URL: data.DownloadUrl,
+            Header: http.Header{
+                "Cookie": []string{data.CookieName + "=" + data.CookieValue},
+            },
+        }, nil
+    }
+    return nil, errs.NotSupport
+}
+
+func (d *WeiYun) MakeDir(ctx context.Context, parentDir model.Obj, dirName string) (model.Obj, error) {
+    if folder, ok := parentDir.(*Folder); ok {
+        newFolder, err := d.client.DiskDirCreate(weiyunsdkgo.FolderParam{
+            PPdirKey: folder.GetPKey(),
+            PdirKey:  folder.DirKey,
+            DirName:  dirName,
+        })
+        if err != nil {
+            return nil, err
+        }
+        return &Folder{
+            PFolder: folder,
+            Folder:  *newFolder,
+        }, nil
+    }
+    return nil, errs.NotSupport
+}
+
+func (d *WeiYun) Move(ctx context.Context, srcObj, dstDir model.Obj) (model.Obj, error) {
+    // TODO: the default conflict policy is rename, which may clash with cached listings.
+    // The WeiYun app has the same conflict, so it is unclear how Tencent handles it.
+    if dstDir, ok := dstDir.(*Folder); ok {
+        dstParam := weiyunsdkgo.FolderParam{
+            PdirKey: dstDir.GetPKey(),
+            DirKey:  dstDir.GetID(),
+            DirName: dstDir.GetName(),
+        }
+        switch srcObj := srcObj.(type) {
+        case *File:
+            err := d.client.DiskFileMove(weiyunsdkgo.FileParam{
+                PPdirKey: srcObj.PFolder.GetPKey(),
+                PdirKey:  srcObj.GetPKey(),
+                FileID:   srcObj.GetID(),
+                FileName: srcObj.GetName(),
+            }, dstParam)
+            if err != nil {
+                return nil, err
+            }
+            return &File{
+                PFolder: dstDir,
+                File:    srcObj.File,
+            }, nil
+        case *Folder:
+            err := d.client.DiskDirMove(weiyunsdkgo.FolderParam{
+                PPdirKey: srcObj.PFolder.GetPKey(),
+                PdirKey:  srcObj.GetPKey(),
+                DirKey:   srcObj.GetID(),
+                DirName:  srcObj.GetName(),
+            }, dstParam)
+            if err != nil {
+                return nil, err
+            }
+            return &Folder{
+                PFolder: dstDir,
+                Folder:  srcObj.Folder,
+            }, nil
+        }
+    }
+    return nil, errs.NotSupport
+}
+
+func (d *WeiYun) Rename(ctx context.Context, srcObj model.Obj, newName string) (model.Obj, error) {
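+    // The SDK only renames server-side, so a fresh object is built below with the
+    // new name (and a bumped ctime, as the code does) and returned for AList's
+    // cache; the type switch mirrors the SDK's separate file and folder endpoints.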
+    switch srcObj := srcObj.(type) {
+    case *File:
+        err := d.client.DiskFileRename(weiyunsdkgo.FileParam{
+            PPdirKey: srcObj.PFolder.GetPKey(),
+            PdirKey:  srcObj.GetPKey(),
+            FileID:   srcObj.GetID(),
+            FileName: srcObj.GetName(),
+        }, newName)
+        if err != nil {
+            return nil, err
+        }
+        newFile := srcObj.File
+        newFile.FileName = newName
+        newFile.FileCtime = weiyunsdkgo.TimeStamp(time.Now())
+        return &File{
+            PFolder: srcObj.PFolder,
+            File:    newFile,
+        }, nil
+    case *Folder:
+        err := d.client.DiskDirAttrModify(weiyunsdkgo.FolderParam{
+            PPdirKey: srcObj.PFolder.GetPKey(),
+            PdirKey:  srcObj.GetPKey(),
+            DirKey:   srcObj.GetID(),
+            DirName:  srcObj.GetName(),
+        }, newName)
+        if err != nil {
+            return nil, err
+        }
+
+        newFolder := srcObj.Folder
+        newFolder.DirName = newName
+        newFolder.DirCtime = weiyunsdkgo.TimeStamp(time.Now())
+        return &Folder{
+            PFolder: srcObj.PFolder,
+            Folder:  newFolder,
+        }, nil
+    }
+    return nil, errs.NotSupport
+}
+
+func (d *WeiYun) Copy(ctx context.Context, srcObj, dstDir model.Obj) error {
+    return errs.NotImplement
+}
+
+func (d *WeiYun) Remove(ctx context.Context, obj model.Obj) error {
+    switch obj := obj.(type) {
+    case *File:
+        return d.client.DiskFileDelete(weiyunsdkgo.FileParam{
+            PPdirKey: obj.PFolder.GetPKey(),
+            PdirKey:  obj.GetPKey(),
+            FileID:   obj.GetID(),
+            FileName: obj.GetName(),
+        })
+    case *Folder:
+        return d.client.DiskDirDelete(weiyunsdkgo.FolderParam{
+            PPdirKey: obj.PFolder.GetPKey(),
+            PdirKey:  obj.GetPKey(),
+            DirKey:   obj.GetID(),
+            DirName:  obj.GetName(),
+        })
+    }
+    return errs.NotSupport
+}
+
+func (d *WeiYun) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) (model.Obj, error) {
+    // NOTE:
+    // Instant upload needs the final SHA-1 state, but SHA-1 cannot be run backwards,
+    // so the whole file has to be read (or maybe it can??).
+    // The server supports resuming upload progress, so nothing extra is needed here.
+    if folder, ok := dstDir.(*Folder); ok {
+        file, err := stream.CacheFullInTempFile()
+        if err != nil {
+            return nil, err
+        }
+
+        // step 1.
+        preData, err := d.client.PreUpload(ctx, weiyunsdkgo.UpdloadFileParam{
+            PdirKey: folder.GetPKey(),
+            DirKey:  folder.DirKey,
+
+            FileName: stream.GetName(),
+            FileSize: stream.GetSize(),
+            File:     file,
+
+            ChannelCount:    4,
+            FileExistOption: 1,
+        })
+        if err != nil {
+            return nil, err
+        }
+
+        // not an instant upload
+        if !preData.FileExist {
+            // step 2: add upload channels
+            if len(preData.ChannelList) < d.uploadThread {
+                newCh, err := d.client.AddUploadChannel(len(preData.ChannelList), d.uploadThread, preData.UploadAuthData)
+                if err != nil {
+                    return nil, err
+                }
+                preData.ChannelList = append(preData.ChannelList, newCh.AddChannels...)
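+                // the channels just appended are independent byte ranges of the
+                // file; with up to uploadThread of them, the errgroup below can
+                // upload several ranges concurrently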
+            }
+            // step 3: upload
+            threadG, upCtx := errgroup.NewGroupWithContext(ctx, len(preData.ChannelList),
+                retry.Attempts(3),
+                retry.Delay(time.Second),
+                retry.DelayType(retry.BackOffDelay))
+
+            for _, channel := range preData.ChannelList {
+                if utils.IsCanceled(upCtx) {
+                    break
+                }
+
+                var channel = channel
+                threadG.Go(func(ctx context.Context) error {
+                    for {
+                        channel.Len = int(math.Min(float64(stream.GetSize()-channel.Offset), float64(channel.Len)))
+                        upData, err := d.client.UploadFile(upCtx, channel, preData.UploadAuthData,
+                            io.NewSectionReader(file, channel.Offset, int64(channel.Len)))
+                        if err != nil {
+                            return err
+                        }
+                        // the upload has finished
+                        if upData.UploadState != 1 {
+                            return nil
+                        }
+                        channel = upData.Channel
+                    }
+                })
+            }
+            if err = threadG.Wait(); err != nil {
+                return nil, err
+            }
+        }
+
+        return &File{
+            PFolder: folder,
+            File:    preData.File,
+        }, nil
+    }
+    return nil, errs.NotSupport
+}
+
+// func (d *WeiYun) Other(ctx context.Context, args model.OtherArgs) (interface{}, error) {
+//    return nil, errs.NotSupport
+// }
+
+var _ driver.Driver = (*WeiYun)(nil)
+var _ driver.GetRooter = (*WeiYun)(nil)
+var _ driver.MkdirResult = (*WeiYun)(nil)
+
+// var _ driver.CopyResult = (*WeiYun)(nil)
+var _ driver.MoveResult = (*WeiYun)(nil)
+var _ driver.Remove = (*WeiYun)(nil)
+
+var _ driver.PutResult = (*WeiYun)(nil)
+var _ driver.RenameResult = (*WeiYun)(nil)
diff --git a/drivers/weiyun/meta.go b/drivers/weiyun/meta.go
new file mode 100644
index 0000000000000000000000000000000000000000..11200b6b3cc2fefc435cddefcc3cd67fef0a02cc
--- /dev/null
+++ b/drivers/weiyun/meta.go
@@ -0,0 +1,29 @@
+package weiyun
+
+import (
+    "github.com/alist-org/alist/v3/internal/driver"
+    "github.com/alist-org/alist/v3/internal/op"
+)
+
+type Addition struct {
+    RootFolderID   string `json:"root_folder_id"`
+    Cookies        string `json:"cookies" required:"true"`
+    OrderBy        string `json:"order_by" type:"select" options:"name,size,updated_at" default:"name"`
+    OrderDirection string `json:"order_direction" type:"select" options:"asc,desc" default:"asc"`
+    UploadThread   string `json:"upload_thread" default:"4" help:"4<=thread<=32"`
+}
+
+var config = driver.Config{
+    Name:              "WeiYun",
+    LocalSort:         false,
+    OnlyProxy:         true,
+    CheckStatus:       true,
+    Alert:             "",
+    NoOverwriteUpload: false,
+}
+
+func init() {
+    op.RegisterDriver(func() driver.Driver {
+        return &WeiYun{}
+    })
+}
diff --git a/drivers/weiyun/types.go b/drivers/weiyun/types.go
new file mode 100644
index 0000000000000000000000000000000000000000..664693c80edb3b7353a41dec0686034a26aceb67
--- /dev/null
+++ b/drivers/weiyun/types.go
@@ -0,0 +1,55 @@
+package weiyun
+
+import (
+    "time"
+
+    "github.com/alist-org/alist/v3/pkg/utils"
+
+    weiyunsdkgo "github.com/foxxorcat/weiyun-sdk-go"
+)
+
+type File struct {
+    PFolder *Folder
+    weiyunsdkgo.File
+}
+
+func (f *File) GetID() string      { return f.FileID }
+func (f *File) GetSize() int64     { return f.FileSize }
+func (f *File) GetName() string    { return f.FileName }
+func (f *File) ModTime() time.Time { return time.Time(f.FileMtime) }
+func (f *File) IsDir() bool        { return false }
+func (f *File) GetPath() string    { return "" }
+
+func (f *File) GetPKey() string {
+    return f.PFolder.DirKey
+}
+func (f *File) CreateTime() time.Time {
+    return time.Time(f.FileCtime)
+}
+
+func (f *File) GetHash() utils.HashInfo {
+    return utils.NewHashInfo(utils.SHA1, f.FileSha)
+}
+
+type Folder struct {
+    PFolder *Folder
+    weiyunsdkgo.Folder
+}
+
+func (f *Folder) CreateTime() time.Time {
+    return time.Time(f.DirCtime)
+}
+
+func (f *Folder) GetHash() utils.HashInfo {
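+    // WeiYun only exposes a SHA-1 for files (see File.GetHash above); folders
+    // carry no content hash, so the zero-value HashInfo is returned.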
+ return utils.HashInfo{} +} + +func (f *Folder) GetID() string { return f.DirKey } +func (f *Folder) GetSize() int64 { return 0 } +func (f *Folder) GetName() string { return f.DirName } +func (f *Folder) ModTime() time.Time { return time.Time(f.DirMtime) } +func (f *Folder) IsDir() bool { return true } +func (f *Folder) GetPath() string { return "" } + +func (f *Folder) GetPKey() string { + return f.PFolder.DirKey +} diff --git a/drivers/wopan/driver.go b/drivers/wopan/driver.go new file mode 100644 index 0000000000000000000000000000000000000000..bccce4b1c0a5e2cfbf66ba128fbd6a6c6f5e225f --- /dev/null +++ b/drivers/wopan/driver.go @@ -0,0 +1,172 @@ +package template + +import ( + "context" + "fmt" + "strconv" + + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/internal/op" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/go-resty/resty/v2" + "github.com/xhofe/wopan-sdk-go" +) + +type Wopan struct { + model.Storage + Addition + client *wopan.WoClient + defaultFamilyID string +} + +func (d *Wopan) Config() driver.Config { + return config +} + +func (d *Wopan) GetAddition() driver.Additional { + return &d.Addition +} + +func (d *Wopan) Init(ctx context.Context) error { + d.client = wopan.DefaultWithRefreshToken(d.RefreshToken) + d.client.SetAccessToken(d.AccessToken) + d.client.OnRefreshToken(func(accessToken, refreshToken string) { + d.AccessToken = accessToken + d.RefreshToken = refreshToken + op.MustSaveDriverStorage(d) + }) + fml, err := d.client.FamilyUserCurrentEncode() + if err != nil { + return err + } + d.defaultFamilyID = strconv.Itoa(fml.DefaultHomeId) + return d.client.InitData() +} + +func (d *Wopan) Drop(ctx context.Context) error { + return nil +} + +func (d *Wopan) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) { + var res []model.Obj + pageNum := 0 + pageSize := 100 + for { + data, err := d.client.QueryAllFiles(d.getSpaceType(), dir.GetID(), pageNum, pageSize, 0, d.FamilyID, func(req *resty.Request) { + req.SetContext(ctx) + }) + if err != nil { + return nil, err + } + objs, err := utils.SliceConvert(data.Files, fileToObj) + if err != nil { + return nil, err + } + res = append(res, objs...) 
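+        // QueryAllFiles pages by pageNum/pageSize: a page shorter than pageSize
+        // marks the end of the listing, otherwise the next page is requested.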
+ if len(data.Files) < pageSize { + break + } + pageNum++ + } + return res, nil +} + +func (d *Wopan) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) { + if f, ok := file.(*Object); ok { + res, err := d.client.GetDownloadUrlV2([]string{f.FID}, func(req *resty.Request) { + req.SetContext(ctx) + }) + if err != nil { + return nil, err + } + return &model.Link{ + URL: res.List[0].DownloadUrl, + }, nil + } + return nil, fmt.Errorf("unable to convert file to Object") +} + +func (d *Wopan) MakeDir(ctx context.Context, parentDir model.Obj, dirName string) error { + familyID := d.FamilyID + if familyID == "" { + familyID = d.defaultFamilyID + } + _, err := d.client.CreateDirectory(d.getSpaceType(), parentDir.GetID(), dirName, familyID, func(req *resty.Request) { + req.SetContext(ctx) + }) + return err +} + +func (d *Wopan) Move(ctx context.Context, srcObj, dstDir model.Obj) error { + dirList := make([]string, 0) + fileList := make([]string, 0) + if srcObj.IsDir() { + dirList = append(dirList, srcObj.GetID()) + } else { + fileList = append(fileList, srcObj.GetID()) + } + return d.client.MoveFile(dirList, fileList, dstDir.GetID(), + d.getSpaceType(), d.getSpaceType(), + d.FamilyID, d.FamilyID, func(req *resty.Request) { + req.SetContext(ctx) + }) +} + +func (d *Wopan) Rename(ctx context.Context, srcObj model.Obj, newName string) error { + _type := 1 + if srcObj.IsDir() { + _type = 0 + } + return d.client.RenameFileOrDirectory(d.getSpaceType(), _type, srcObj.GetID(), newName, d.FamilyID, func(req *resty.Request) { + req.SetContext(ctx) + }) +} + +func (d *Wopan) Copy(ctx context.Context, srcObj, dstDir model.Obj) error { + dirList := make([]string, 0) + fileList := make([]string, 0) + if srcObj.IsDir() { + dirList = append(dirList, srcObj.GetID()) + } else { + fileList = append(fileList, srcObj.GetID()) + } + return d.client.CopyFile(dirList, fileList, dstDir.GetID(), + d.getSpaceType(), d.getSpaceType(), + d.FamilyID, d.FamilyID, func(req *resty.Request) { + req.SetContext(ctx) + }) +} + +func (d *Wopan) Remove(ctx context.Context, obj model.Obj) error { + dirList := make([]string, 0) + fileList := make([]string, 0) + if obj.IsDir() { + dirList = append(dirList, obj.GetID()) + } else { + fileList = append(fileList, obj.GetID()) + } + return d.client.DeleteFile(d.getSpaceType(), dirList, fileList, func(req *resty.Request) { + req.SetContext(ctx) + }) +} + +func (d *Wopan) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error { + _, err := d.client.Upload2C(d.getSpaceType(), wopan.Upload2CFile{ + Name: stream.GetName(), + Size: stream.GetSize(), + Content: stream, + ContentType: stream.GetMimetype(), + }, dstDir.GetID(), d.FamilyID, wopan.Upload2COption{ + OnProgress: func(current, total int64) { + up(100 * float64(current) / float64(total)) + }, + }) + return err +} + +//func (d *Wopan) Other(ctx context.Context, args model.OtherArgs) (interface{}, error) { +// return nil, errs.NotSupport +//} + +var _ driver.Driver = (*Wopan)(nil) diff --git a/drivers/wopan/meta.go b/drivers/wopan/meta.go new file mode 100644 index 0000000000000000000000000000000000000000..364eca918a309b97089dc1aa2bf17fc46df8e0ab --- /dev/null +++ b/drivers/wopan/meta.go @@ -0,0 +1,37 @@ +package template + +import ( + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/op" +) + +type Addition struct { + // Usually one of two + driver.RootID + // define other + RefreshToken string `json:"refresh_token" 
required:"true"` + FamilyID string `json:"family_id" help:"Keep it empty if you want to use your personal drive"` + SortRule string `json:"sort_rule" type:"select" options:"name_asc,name_desc,time_asc,time_desc,size_asc,size_desc" default:"name_asc"` + + AccessToken string `json:"access_token"` +} + +var config = driver.Config{ + Name: "WoPan", + LocalSort: false, + OnlyLocal: false, + OnlyProxy: false, + NoCache: false, + NoUpload: false, + NeedMs: false, + DefaultRoot: "0", + CheckStatus: false, + Alert: "", + NoOverwriteUpload: true, +} + +func init() { + op.RegisterDriver(func() driver.Driver { + return &Wopan{} + }) +} diff --git a/drivers/wopan/types.go b/drivers/wopan/types.go new file mode 100644 index 0000000000000000000000000000000000000000..4025dbab237aaba669caa97b33c1aef37125d292 --- /dev/null +++ b/drivers/wopan/types.go @@ -0,0 +1,34 @@ +package template + +import ( + "github.com/alist-org/alist/v3/internal/model" + "github.com/xhofe/wopan-sdk-go" +) + +type Object struct { + model.ObjThumb + FID string +} + +func fileToObj(file wopan.File) (model.Obj, error) { + t, err := getTime(file.CreateTime) + if err != nil { + return nil, err + } + return &Object{ + ObjThumb: model.ObjThumb{ + Object: model.Object{ + ID: file.Id, + //Path: "", + Name: file.Name, + Size: file.Size, + Modified: t, + IsFolder: file.Type == 0, + }, + Thumbnail: model.Thumbnail{ + Thumbnail: file.ThumbUrl, + }, + }, + FID: file.Fid, + }, nil +} diff --git a/drivers/wopan/util.go b/drivers/wopan/util.go new file mode 100644 index 0000000000000000000000000000000000000000..b825d6ea9ac9edd0d0d806a1885d8113d36f43a5 --- /dev/null +++ b/drivers/wopan/util.go @@ -0,0 +1,40 @@ +package template + +import ( + "time" + + "github.com/xhofe/wopan-sdk-go" +) + +// do others that not defined in Driver interface + +func (d *Wopan) getSortRule() int { + switch d.SortRule { + case "name_asc": + return wopan.SortNameAsc + case "name_desc": + return wopan.SortNameDesc + case "time_asc": + return wopan.SortTimeAsc + case "time_desc": + return wopan.SortTimeDesc + case "size_asc": + return wopan.SortSizeAsc + case "size_desc": + return wopan.SortSizeDesc + default: + return wopan.SortNameAsc + } +} + +func (d *Wopan) getSpaceType() string { + if d.FamilyID == "" { + return wopan.SpaceTypePersonal + } + return wopan.SpaceTypeFamily +} + +// 20230607214351 +func getTime(str string) (time.Time, error) { + return time.Parse("20060102150405", str) +} diff --git a/drivers/yandex_disk/driver.go b/drivers/yandex_disk/driver.go new file mode 100644 index 0000000000000000000000000000000000000000..5af9f2e4fb0b6b6c5748e6274437b23d10e7f347 --- /dev/null +++ b/drivers/yandex_disk/driver.go @@ -0,0 +1,132 @@ +package yandex_disk + +import ( + "context" + "net/http" + "path" + "strconv" + + "github.com/alist-org/alist/v3/drivers/base" + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/go-resty/resty/v2" +) + +type YandexDisk struct { + model.Storage + Addition + AccessToken string +} + +func (d *YandexDisk) Config() driver.Config { + return config +} + +func (d *YandexDisk) GetAddition() driver.Additional { + return &d.Addition +} + +func (d *YandexDisk) Init(ctx context.Context) error { + return d.refreshToken() +} + +func (d *YandexDisk) Drop(ctx context.Context) error { + return nil +} + +func (d *YandexDisk) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) { + files, err := d.getFiles(dir.GetPath()) + 
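+    // the listing is fetched page by page (100 entries per request) inside
+    // getFiles; see util.go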
if err != nil {
+        return nil, err
+    }
+    return utils.SliceConvert(files, func(src File) (model.Obj, error) {
+        return fileToObj(src), nil
+    })
+}
+
+func (d *YandexDisk) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) {
+    var resp DownResp
+    _, err := d.request("/download", http.MethodGet, func(req *resty.Request) {
+        req.SetQueryParam("path", file.GetPath())
+    }, &resp)
+    if err != nil {
+        return nil, err
+    }
+    link := model.Link{
+        URL: resp.Href,
+    }
+    return &link, nil
+}
+
+func (d *YandexDisk) MakeDir(ctx context.Context, parentDir model.Obj, dirName string) error {
+    _, err := d.request("", http.MethodPut, func(req *resty.Request) {
+        req.SetQueryParam("path", path.Join(parentDir.GetPath(), dirName))
+    }, nil)
+    return err
+}
+
+func (d *YandexDisk) Move(ctx context.Context, srcObj, dstDir model.Obj) error {
+    _, err := d.request("/move", http.MethodPost, func(req *resty.Request) {
+        req.SetQueryParams(map[string]string{
+            "from":      srcObj.GetPath(),
+            "path":      path.Join(dstDir.GetPath(), srcObj.GetName()),
+            "overwrite": "true",
+        })
+    }, nil)
+    return err
+}
+
+func (d *YandexDisk) Rename(ctx context.Context, srcObj model.Obj, newName string) error {
+    _, err := d.request("/move", http.MethodPost, func(req *resty.Request) {
+        req.SetQueryParams(map[string]string{
+            "from":      srcObj.GetPath(),
+            "path":      path.Join(path.Dir(srcObj.GetPath()), newName),
+            "overwrite": "true",
+        })
+    }, nil)
+    return err
+}
+
+func (d *YandexDisk) Copy(ctx context.Context, srcObj, dstDir model.Obj) error {
+    _, err := d.request("/copy", http.MethodPost, func(req *resty.Request) {
+        req.SetQueryParams(map[string]string{
+            "from":      srcObj.GetPath(),
+            "path":      path.Join(dstDir.GetPath(), srcObj.GetName()),
+            "overwrite": "true",
+        })
+    }, nil)
+    return err
+}
+
+func (d *YandexDisk) Remove(ctx context.Context, obj model.Obj) error {
+    _, err := d.request("", http.MethodDelete, func(req *resty.Request) {
+        req.SetQueryParam("path", obj.GetPath())
+    }, nil)
+    return err
+}
+
+func (d *YandexDisk) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error {
+    var resp UploadResp
+    _, err := d.request("/upload", http.MethodGet, func(req *resty.Request) {
+        req.SetQueryParams(map[string]string{
+            "path":      path.Join(dstDir.GetPath(), stream.GetName()),
+            "overwrite": "true",
+        })
+    }, &resp)
+    if err != nil {
+        return err
+    }
+    req, err := http.NewRequest(resp.Method, resp.Href, stream)
+    if err != nil {
+        return err
+    }
+    req = req.WithContext(ctx)
+    req.Header.Set("Content-Length", strconv.FormatInt(stream.GetSize(), 10))
+    req.Header.Set("Content-Type", "application/octet-stream")
+    res, err := base.HttpClient.Do(req)
+    if err != nil {
+        // a failed request returns a nil response; closing its body would panic
+        return err
+    }
+    _ = res.Body.Close()
+    return nil
+}
+
+var _ driver.Driver = (*YandexDisk)(nil)
diff --git a/drivers/yandex_disk/meta.go b/drivers/yandex_disk/meta.go
new file mode 100644
index 0000000000000000000000000000000000000000..0a1fabdcbe4cd825890736ad85e24eb9880c4424
--- /dev/null
+++ b/drivers/yandex_disk/meta.go
@@ -0,0 +1,26 @@
+package yandex_disk
+
+import (
+    "github.com/alist-org/alist/v3/internal/driver"
+    "github.com/alist-org/alist/v3/internal/op"
+)
+
+type Addition struct {
+    RefreshToken   string `json:"refresh_token" required:"true"`
+    OrderBy        string `json:"order_by" type:"select" options:"name,path,created,modified,size" default:"name"`
+    OrderDirection string `json:"order_direction" type:"select" options:"asc,desc" default:"asc"`
+    driver.RootPath
+    ClientID string `json:"client_id" required:"true" default:"a78d5a69054042fa936f6c77f9a0ae8b"`
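+    // the default client_id/client_secret are AList's published OAuth app
+    // credentials for Yandex Disk, not user secrets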
+    ClientSecret string `json:"client_secret" required:"true" default:"9c119bbb04b346d2a52aa64401936b2b"`
+}
+
+var config = driver.Config{
+    Name:        "YandexDisk",
+    DefaultRoot: "/",
+}
+
+func init() {
+    op.RegisterDriver(func() driver.Driver {
+        return &YandexDisk{}
+    })
+}
diff --git a/drivers/yandex_disk/types.go b/drivers/yandex_disk/types.go
new file mode 100644
index 0000000000000000000000000000000000000000..111bd464410f3974de5d93a4e291ba218460d581
--- /dev/null
+++ b/drivers/yandex_disk/types.go
@@ -0,0 +1,87 @@
+package yandex_disk
+
+import (
+    "time"
+
+    "github.com/alist-org/alist/v3/internal/model"
+)
+
+type TokenErrResp struct {
+    ErrorDescription string `json:"error_description"`
+    Error            string `json:"error"`
+}
+
+type ErrResp struct {
+    Message     string `json:"message"`
+    Description string `json:"description"`
+    Error       string `json:"error"`
+}
+
+type File struct {
+    //AntivirusStatus string `json:"antivirus_status"`
+    Size int64 `json:"size"`
+    //CommentIds struct {
+    //    PrivateResource string `json:"private_resource"`
+    //    PublicResource  string `json:"public_resource"`
+    //} `json:"comment_ids"`
+    Name string `json:"name"`
+    //Exif struct {
+    //    DateTime time.Time `json:"date_time"`
+    //} `json:"exif"`
+    //Created time.Time `json:"created"`
+    //ResourceId string `json:"resource_id"`
+    Modified time.Time `json:"modified"`
+    //MimeType string `json:"mime_type"`
+    File string `json:"file"`
+    //MediaType string `json:"media_type"`
+    Preview string `json:"preview"`
+    Path    string `json:"path"`
+    //Sha256 string `json:"sha256"`
+    Type string `json:"type"`
+    //Md5 string `json:"md5"`
+    //Revision int64 `json:"revision"`
+}
+
+func fileToObj(f File) model.Obj {
+    return &model.Object{
+        Name:     f.Name,
+        Size:     f.Size,
+        Modified: f.Modified,
+        IsFolder: f.Type == "dir",
+    }
+}
+
+type FilesResp struct {
+    Embedded struct {
+        Sort   string `json:"sort"`
+        Items  []File `json:"items"`
+        Limit  int    `json:"limit"`
+        Offset int    `json:"offset"`
+        Path   string `json:"path"`
+        Total  int    `json:"total"`
+    } `json:"_embedded"`
+    Name string `json:"name"`
+    Exif struct {
+    } `json:"exif"`
+    ResourceId string    `json:"resource_id"`
+    Created    time.Time `json:"created"`
+    Modified   time.Time `json:"modified"`
+    Path       string    `json:"path"`
+    CommentIds struct {
+    } `json:"comment_ids"`
+    Type     string `json:"type"`
+    Revision int64  `json:"revision"`
+}
+
+type DownResp struct {
+    Href      string `json:"href"`
+    Method    string `json:"method"`
+    Templated bool   `json:"templated"`
+}
+
+type UploadResp struct {
+    OperationId string `json:"operation_id"`
+    Href        string `json:"href"`
+    Method      string `json:"method"`
+    Templated   bool   `json:"templated"`
+}
diff --git a/drivers/yandex_disk/util.go b/drivers/yandex_disk/util.go
new file mode 100644
index 0000000000000000000000000000000000000000..c3ffc29542623ca44364ea0305a38f117d98e5a8
--- /dev/null
+++ b/drivers/yandex_disk/util.go
@@ -0,0 +1,98 @@
+package yandex_disk
+
+import (
+    "errors"
+    "fmt"
+    "net/http"
+    "strconv"
+
+    "github.com/alist-org/alist/v3/drivers/base"
+    "github.com/alist-org/alist/v3/internal/op"
+    "github.com/go-resty/resty/v2"
+)
+
+// do others that not defined in Driver interface
+
+func (d *YandexDisk) refreshToken() error {
+    u := "https://oauth.yandex.com/token"
+    var resp base.TokenResp
+    var e TokenErrResp
+    _, err := base.RestyClient.R().SetResult(&resp).SetError(&e).SetFormData(map[string]string{
+        "grant_type":    "refresh_token",
+        "refresh_token": d.RefreshToken,
+        "client_id":     d.ClientID,
+        "client_secret": d.ClientSecret,
+    }).Post(u)
+    if err != nil {
+        return err
+    }
+    if e.Error != "" {
+        return fmt.Errorf("%s : %s", e.Error, e.ErrorDescription)
+    }
+    d.AccessToken, d.RefreshToken = resp.AccessToken, resp.RefreshToken
+    op.MustSaveDriverStorage(d)
+    return nil
+}
+
+func (d *YandexDisk) request(pathname string, method string, callback base.ReqCallback, resp interface{}) ([]byte, error) {
+    u := "https://cloud-api.yandex.net/v1/disk/resources" + pathname
+    req := base.RestyClient.R()
+    req.SetHeader("Authorization", "OAuth "+d.AccessToken)
+    if callback != nil {
+        callback(req)
+    }
+    if resp != nil {
+        req.SetResult(resp)
+    }
+    var e ErrResp
+    req.SetError(&e)
+    res, err := req.Execute(method, u)
+    if err != nil {
+        return nil, err
+    }
+    //log.Debug(res.String())
+    if e.Error != "" {
+        if e.Error == "UnauthorizedError" {
+            err = d.refreshToken()
+            if err != nil {
+                return nil, err
+            }
+            return d.request(pathname, method, callback, resp)
+        }
+        return nil, errors.New(e.Description)
+    }
+    return res.Body(), nil
+}
+
+func (d *YandexDisk) getFiles(path string) ([]File, error) {
+    limit := 100
+    page := 1
+    res := make([]File, 0)
+    for {
+        offset := (page - 1) * limit
+        query := map[string]string{
+            "path":   path,
+            "limit":  strconv.Itoa(limit),
+            "offset": strconv.Itoa(offset),
+        }
+        if d.OrderBy != "" {
+            if d.OrderDirection == "desc" {
+                query["sort"] = "-" + d.OrderBy
+            } else {
+                query["sort"] = d.OrderBy
+            }
+        }
+        var resp FilesResp
+        _, err := d.request("", http.MethodGet, func(req *resty.Request) {
+            req.SetQueryParams(query)
+        }, &resp)
+        if err != nil {
+            return nil, err
+        }
+        res = append(res, resp.Embedded.Items...)
+        if resp.Embedded.Total <= offset+limit {
+            break
+        }
+        // move to the next page
+        page++
+    }
+    return res, nil
+}
diff --git a/internal/authn/authn.go b/internal/authn/authn.go
new file mode 100644
index 0000000000000000000000000000000000000000..ea621d048af5c88153e76ba1608cd9916a78bd86
--- /dev/null
+++ b/internal/authn/authn.go
@@ -0,0 +1,26 @@
+package authn
+
+import (
+    "fmt"
+    "net/http"
+    "net/url"
+
+    "github.com/alist-org/alist/v3/internal/conf"
+    "github.com/alist-org/alist/v3/internal/setting"
+    "github.com/alist-org/alist/v3/server/common"
+    "github.com/go-webauthn/webauthn/webauthn"
+)
+
+func NewAuthnInstance(r *http.Request) (*webauthn.WebAuthn, error) {
+    siteUrl, err := url.Parse(common.GetApiUrl(r))
+    if err != nil {
+        return nil, err
+    }
+    return webauthn.New(&webauthn.Config{
+        RPDisplayName: setting.GetStr(conf.SiteTitle),
+        RPID:          siteUrl.Hostname(),
+        //RPOrigin: siteUrl.String(),
+        RPOrigins: []string{fmt.Sprintf("%s://%s", siteUrl.Scheme, siteUrl.Host)},
+        // RPOrigin: "http://localhost:5173"
+    })
+}
diff --git a/internal/bootstrap/config.go b/internal/bootstrap/config.go
new file mode 100644
index 0000000000000000000000000000000000000000..bf454a4b71b79a11365b23b2850d6c9cdc014192
--- /dev/null
+++ b/internal/bootstrap/config.go
@@ -0,0 +1,118 @@
+package bootstrap
+
+import (
+    "net/url"
+    "os"
+    "path/filepath"
+    "strings"
+
+    "github.com/alist-org/alist/v3/cmd/flags"
+    "github.com/alist-org/alist/v3/drivers/base"
+    "github.com/alist-org/alist/v3/internal/conf"
+    "github.com/alist-org/alist/v3/pkg/utils"
+    "github.com/caarlos0/env/v9"
+    log "github.com/sirupsen/logrus"
+)
+
+func InitConfig() {
+    if flags.ForceBinDir {
+        if !filepath.IsAbs(flags.DataDir) {
+            ex, err := os.Executable()
+            if err != nil {
+                utils.Log.Fatal(err)
+            }
+            exPath := filepath.Dir(ex)
+            flags.DataDir = filepath.Join(exPath, flags.DataDir)
+        }
+    }
+    configPath := filepath.Join(flags.DataDir, 
"config.json") + log.Infof("reading config file: %s", configPath) + if !utils.Exists(configPath) { + log.Infof("config file not exists, creating default config file") + _, err := utils.CreateNestedFile(configPath) + if err != nil { + log.Fatalf("failed to create config file: %+v", err) + } + conf.Conf = conf.DefaultConfig() + if !utils.WriteJsonToFile(configPath, conf.Conf) { + log.Fatalf("failed to create default config file") + } + } else { + configBytes, err := os.ReadFile(configPath) + if err != nil { + log.Fatalf("reading config file error: %+v", err) + } + conf.Conf = conf.DefaultConfig() + err = utils.Json.Unmarshal(configBytes, conf.Conf) + if err != nil { + log.Fatalf("load config error: %+v", err) + } + // update config.json struct + confBody, err := utils.Json.MarshalIndent(conf.Conf, "", " ") + if err != nil { + log.Fatalf("marshal config error: %+v", err) + } + err = os.WriteFile(configPath, confBody, 0o777) + if err != nil { + log.Fatalf("update config struct error: %+v", err) + } + } + if !conf.Conf.Force { + confFromEnv() + } + // convert abs path + if !filepath.IsAbs(conf.Conf.TempDir) { + absPath, err := filepath.Abs(conf.Conf.TempDir) + if err != nil { + log.Fatalf("get abs path error: %+v", err) + } + conf.Conf.TempDir = absPath + } + err := os.RemoveAll(filepath.Join(conf.Conf.TempDir)) + if err != nil { + log.Errorln("failed delete temp file:", err) + } + err = os.MkdirAll(conf.Conf.TempDir, 0o777) + if err != nil { + log.Fatalf("create temp dir error: %+v", err) + } + log.Debugf("config: %+v", conf.Conf) + base.InitClient() + initURL() +} + +func confFromEnv() { + prefix := "ALIST_" + if flags.NoPrefix { + prefix = "" + } + log.Infof("load config from env with prefix: %s", prefix) + if err := env.ParseWithOptions(conf.Conf, env.Options{ + Prefix: prefix, + }); err != nil { + log.Fatalf("load config from env error: %+v", err) + } +} + +func initURL() { + if !strings.Contains(conf.Conf.SiteURL, "://") { + conf.Conf.SiteURL = utils.FixAndCleanPath(conf.Conf.SiteURL) + } + u, err := url.Parse(conf.Conf.SiteURL) + if err != nil { + utils.Log.Fatalf("can't parse site_url: %+v", err) + } + conf.URL = u +} + +func CleanTempDir() { + files, err := os.ReadDir(conf.Conf.TempDir) + if err != nil { + log.Errorln("failed list temp file: ", err) + } + for _, file := range files { + if err := os.RemoveAll(filepath.Join(conf.Conf.TempDir, file.Name())); err != nil { + log.Errorln("failed delete temp file: ", err) + } + } +} diff --git a/internal/bootstrap/data/data.go b/internal/bootstrap/data/data.go new file mode 100644 index 0000000000000000000000000000000000000000..c2170d2f47952f21b107677336f35fe27e5e9f60 --- /dev/null +++ b/internal/bootstrap/data/data.go @@ -0,0 +1,13 @@ +package data + +import "github.com/alist-org/alist/v3/cmd/flags" + +func InitData() { + initUser() + initSettings() + initTasks() + if flags.Dev { + initDevData() + initDevDo() + } +} diff --git a/internal/bootstrap/data/dev.go b/internal/bootstrap/data/dev.go new file mode 100644 index 0000000000000000000000000000000000000000..f6296c9e96a40768e82bac845850f30ccc3be0da --- /dev/null +++ b/internal/bootstrap/data/dev.go @@ -0,0 +1,55 @@ +package data + +import ( + "context" + + "github.com/alist-org/alist/v3/cmd/flags" + "github.com/alist-org/alist/v3/internal/db" + "github.com/alist-org/alist/v3/internal/message" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/internal/op" + log "github.com/sirupsen/logrus" +) + +func initDevData() { + _, err := 
op.CreateStorage(context.Background(), model.Storage{ + MountPath: "/", + Order: 0, + Driver: "Local", + Status: "", + Addition: `{"root_folder_path":"."}`, + }) + if err != nil { + log.Fatalf("failed to create storage: %+v", err) + } + err = db.CreateUser(&model.User{ + Username: "Noah", + Password: "hsu", + BasePath: "/data", + Role: 0, + Permission: 512, + }) + if err != nil { + log.Fatalf("failed to create user: %+v", err) + } +} + +func initDevDo() { + if flags.Dev { + go func() { + err := message.GetMessenger().WaitSend(message.Message{ + Type: "string", + Content: "dev mode", + }, 10) + if err != nil { + log.Debugf("%+v", err) + } + m, err := message.GetMessenger().WaitReceive(10) + if err != nil { + log.Debugf("%+v", err) + } else { + log.Debugf("received: %+v", m) + } + }() + } +} diff --git a/internal/bootstrap/data/setting.go b/internal/bootstrap/data/setting.go new file mode 100644 index 0000000000000000000000000000000000000000..93132c9973051b46cac6c658b23e67e3cbbe8c6c --- /dev/null +++ b/internal/bootstrap/data/setting.go @@ -0,0 +1,207 @@ +package data + +import ( + "github.com/alist-org/alist/v3/cmd/flags" + "github.com/alist-org/alist/v3/internal/conf" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/internal/offline_download/tool" + "github.com/alist-org/alist/v3/internal/op" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/alist-org/alist/v3/pkg/utils/random" + "github.com/pkg/errors" + "gorm.io/gorm" +) + +var initialSettingItems []model.SettingItem + +func initSettings() { + InitialSettings() + // check deprecated + settings, err := op.GetSettingItems() + if err != nil { + utils.Log.Fatalf("failed get settings: %+v", err) + } + for i := range settings { + if !isActive(settings[i].Key) && settings[i].Flag != model.DEPRECATED { + settings[i].Flag = model.DEPRECATED + err = op.SaveSettingItem(&settings[i]) + if err != nil { + utils.Log.Fatalf("failed save setting: %+v", err) + } + } + } + + // create or save setting + for i := range initialSettingItems { + item := &initialSettingItems[i] + if item.PreDefault == "" { + item.PreDefault = item.Value + } + // err + stored, err := op.GetSettingItemByKey(item.Key) + if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { + utils.Log.Fatalf("failed get setting: %+v", err) + continue + } + // save + if stored != nil && item.Key != conf.VERSION && stored.Value != item.PreDefault { + item.Value = stored.Value + } + if stored == nil || *item != *stored { + err = op.SaveSettingItem(item) + if err != nil { + utils.Log.Fatalf("failed save setting: %+v", err) + } + } else { + // Not save so needs to execute hook + _, err = op.HandleSettingItemHook(item) + if err != nil { + utils.Log.Errorf("failed to execute hook on %s: %+v", item.Key, err) + } + } + } +} + +func isActive(key string) bool { + for _, item := range initialSettingItems { + if item.Key == key { + return true + } + } + return false +} + +func InitialSettings() []model.SettingItem { + var token string + if flags.Dev { + token = "dev_token" + } else { + token = random.Token() + } + initialSettingItems = []model.SettingItem{ + // site settings + {Key: conf.VERSION, Value: conf.Version, Type: conf.TypeString, Group: model.SITE, Flag: model.READONLY}, + //{Key: conf.ApiUrl, Value: "", Type: conf.TypeString, Group: model.SITE}, + //{Key: conf.BasePath, Value: "", Type: conf.TypeString, Group: model.SITE}, + {Key: conf.SiteTitle, Value: "AList", Type: conf.TypeString, Group: model.SITE}, + {Key: conf.Announcement, Value: "### 
repo\nhttps://github.com/alist-org/alist", Type: conf.TypeText, Group: model.SITE}, + {Key: "pagination_type", Value: "all", Type: conf.TypeSelect, Options: "all,pagination,load_more,auto_load_more", Group: model.SITE}, + {Key: "default_page_size", Value: "30", Type: conf.TypeNumber, Group: model.SITE}, + {Key: conf.AllowIndexed, Value: "false", Type: conf.TypeBool, Group: model.SITE}, + {Key: conf.AllowMounted, Value: "true", Type: conf.TypeBool, Group: model.SITE}, + {Key: conf.RobotsTxt, Value: "User-agent: *\nAllow: /", Type: conf.TypeText, Group: model.SITE}, + // style settings + {Key: conf.Logo, Value: "https://cdn.jsdelivr.net/gh/alist-org/logo@main/logo.svg", Type: conf.TypeText, Group: model.STYLE}, + {Key: conf.Favicon, Value: "https://cdn.jsdelivr.net/gh/alist-org/logo@main/logo.svg", Type: conf.TypeString, Group: model.STYLE}, + {Key: conf.MainColor, Value: "#1890ff", Type: conf.TypeString, Group: model.STYLE}, + {Key: "home_icon", Value: "🏠", Type: conf.TypeString, Group: model.STYLE}, + {Key: "home_container", Value: "max_980px", Type: conf.TypeSelect, Options: "max_980px,hope_container", Group: model.STYLE}, + {Key: "settings_layout", Value: "list", Type: conf.TypeSelect, Options: "list,responsive", Group: model.STYLE}, + // preview settings + {Key: conf.TextTypes, Value: "txt,htm,html,xml,java,properties,sql,js,md,json,conf,ini,vue,php,py,bat,gitignore,yml,go,sh,c,cpp,h,hpp,tsx,vtt,srt,ass,rs,lrc", Type: conf.TypeText, Group: model.PREVIEW, Flag: model.PRIVATE}, + {Key: conf.AudioTypes, Value: "mp3,flac,ogg,m4a,wav,opus,wma", Type: conf.TypeText, Group: model.PREVIEW, Flag: model.PRIVATE}, + {Key: conf.VideoTypes, Value: "mp4,mkv,avi,mov,rmvb,webm,flv,m3u8", Type: conf.TypeText, Group: model.PREVIEW, Flag: model.PRIVATE}, + {Key: "customize_players", Value: "VideoPlayer@@videoplayerapp://open?url=$durl", Type: conf.TypeText, Group: model.PREVIEW}, + {Key: conf.ImageTypes, Value: "jpg,tiff,jpeg,png,gif,bmp,svg,ico,swf,webp", Type: conf.TypeText, Group: model.PREVIEW, Flag: model.PRIVATE}, + //{Key: conf.OfficeTypes, Value: "doc,docx,xls,xlsx,ppt,pptx", Type: conf.TypeText, Group: model.PREVIEW, Flag: model.PRIVATE}, + {Key: conf.ProxyTypes, Value: "m3u8", Type: conf.TypeText, Group: model.PREVIEW, Flag: model.PRIVATE}, + {Key: conf.ProxyIgnoreHeaders, Value: "authorization,referer", Type: conf.TypeText, Group: model.PREVIEW, Flag: model.PRIVATE}, + {Key: "external_previews", Value: `{}`, Type: conf.TypeText, Group: model.PREVIEW}, + {Key: "iframe_previews", Value: `{ + "doc,docx,xls,xlsx,ppt,pptx": { + "Microsoft":"https://view.officeapps.live.com/op/view.aspx?src=$e_url", + "Google":"https://docs.google.com/gview?url=$e_url&embedded=true" + }, + "pdf": { + "PDF.js":"https://alist-org.github.io/pdf.js/web/viewer.html?file=$e_url" + }, + "epub": { + "EPUB.js":"https://alist-org.github.io/static/epub.js/viewer.html?url=$e_url" + } +}`, Type: conf.TypeText, Group: model.PREVIEW}, + // {Key: conf.OfficeViewers, Value: `{ + // "Microsoft":"https://view.officeapps.live.com/op/view.aspx?src=$url", + // "Google":"https://docs.google.com/gview?url=$url&embedded=true", + //}`, Type: conf.TypeText, Group: model.PREVIEW}, + // {Key: conf.PdfViewers, Value: `{ + // "pdf.js":"https://alist-org.github.io/pdf.js/web/viewer.html?file=$url" + //}`, Type: conf.TypeText, Group: model.PREVIEW}, + {Key: "audio_cover", Value: "https://jsd.nn.ci/gh/alist-org/logo@main/logo.svg", Type: conf.TypeString, Group: model.PREVIEW}, + {Key: conf.AudioAutoplay, Value: "true", Type: conf.TypeBool, Group: 
model.PREVIEW}, + {Key: conf.VideoAutoplay, Value: "true", Type: conf.TypeBool, Group: model.PREVIEW}, + // global settings + {Key: conf.HideFiles, Value: "/\\/README.md/i", Type: conf.TypeText, Group: model.GLOBAL}, + {Key: "package_download", Value: "true", Type: conf.TypeBool, Group: model.GLOBAL}, + {Key: conf.CustomizeHead, PreDefault: ``, Type: conf.TypeText, Group: model.GLOBAL, Flag: model.PRIVATE}, + {Key: conf.CustomizeBody, Type: conf.TypeText, Group: model.GLOBAL, Flag: model.PRIVATE}, + {Key: conf.LinkExpiration, Value: "0", Type: conf.TypeNumber, Group: model.GLOBAL, Flag: model.PRIVATE}, + {Key: conf.SignAll, Value: "true", Type: conf.TypeBool, Group: model.GLOBAL, Flag: model.PRIVATE}, + {Key: conf.PrivacyRegs, Value: `(?:(?:\d|[1-9]\d|1\d\d|2[0-4]\d|25[0-5])\.){3}(?:\d|[1-9]\d|1\d\d|2[0-4]\d|25[0-5]) +([[:xdigit:]]{1,4}(?::[[:xdigit:]]{1,4}){7}|::|:(?::[[:xdigit:]]{1,4}){1,6}|[[:xdigit:]]{1,4}:(?::[[:xdigit:]]{1,4}){1,5}|(?:[[:xdigit:]]{1,4}:){2}(?::[[:xdigit:]]{1,4}){1,4}|(?:[[:xdigit:]]{1,4}:){3}(?::[[:xdigit:]]{1,4}){1,3}|(?:[[:xdigit:]]{1,4}:){4}(?::[[:xdigit:]]{1,4}){1,2}|(?:[[:xdigit:]]{1,4}:){5}:[[:xdigit:]]{1,4}|(?:[[:xdigit:]]{1,4}:){1,6}:) +(?U)access_token=(.*)&`, + Type: conf.TypeText, Group: model.GLOBAL, Flag: model.PRIVATE}, + {Key: conf.OcrApi, Value: "https://api.nn.ci/ocr/file/json", Type: conf.TypeString, Group: model.GLOBAL}, + {Key: conf.FilenameCharMapping, Value: `{"/": "|"}`, Type: conf.TypeText, Group: model.GLOBAL}, + {Key: conf.ForwardDirectLinkParams, Value: "false", Type: conf.TypeBool, Group: model.GLOBAL}, + {Key: conf.IgnoreDirectLinkParams, Value: "sign,alist_ts", Type: conf.TypeString, Group: model.GLOBAL}, + {Key: conf.StorageGroups, Value: "sign,alist_ts", Type: conf.TypeString, Group: model.GLOBAL}, + {Key: conf.WebauthnLoginEnabled, Value: "false", Type: conf.TypeBool, Group: model.GLOBAL, Flag: model.PUBLIC}, + + // single settings + {Key: conf.Token, Value: token, Type: conf.TypeString, Group: model.SINGLE, Flag: model.PRIVATE}, + {Key: conf.SearchIndex, Value: "none", Type: conf.TypeSelect, Options: "database,database_non_full_text,bleve,meilisearch,none", Group: model.INDEX}, + {Key: conf.AutoUpdateIndex, Value: "false", Type: conf.TypeBool, Group: model.INDEX}, + {Key: conf.IgnorePaths, Value: "", Type: conf.TypeText, Group: model.INDEX, Flag: model.PRIVATE, Help: `one path per line`}, + {Key: conf.MaxIndexDepth, Value: "20", Type: conf.TypeNumber, Group: model.INDEX, Flag: model.PRIVATE, Help: `max depth of index`}, + {Key: conf.IndexProgress, Value: "{}", Type: conf.TypeText, Group: model.SINGLE, Flag: model.PRIVATE}, + + // SSO settings + {Key: conf.SSOLoginEnabled, Value: "false", Type: conf.TypeBool, Group: model.SSO, Flag: model.PUBLIC}, + {Key: conf.SSOLoginPlatform, Type: conf.TypeSelect, Options: "Casdoor,Github,Microsoft,Google,Dingtalk,OIDC", Group: model.SSO, Flag: model.PUBLIC}, + {Key: conf.SSOClientId, Value: "", Type: conf.TypeString, Group: model.SSO, Flag: model.PRIVATE}, + {Key: conf.SSOClientSecret, Value: "", Type: conf.TypeString, Group: model.SSO, Flag: model.PRIVATE}, + {Key: conf.SSOOIDCUsernameKey, Value: "name", Type: conf.TypeString, Group: model.SSO, Flag: model.PRIVATE}, + {Key: conf.SSOOrganizationName, Value: "", Type: conf.TypeString, Group: model.SSO, Flag: model.PRIVATE}, + {Key: conf.SSOApplicationName, Value: "", Type: conf.TypeString, Group: model.SSO, Flag: model.PRIVATE}, + {Key: conf.SSOEndpointName, Value: "", Type: conf.TypeString, Group: model.SSO, Flag: model.PRIVATE}, + {Key: 
conf.SSOJwtPublicKey, Value: "", Type: conf.TypeString, Group: model.SSO, Flag: model.PRIVATE}, + {Key: conf.SSOAutoRegister, Value: "false", Type: conf.TypeBool, Group: model.SSO, Flag: model.PRIVATE}, + {Key: conf.SSODefaultDir, Value: "/", Type: conf.TypeString, Group: model.SSO, Flag: model.PRIVATE}, + {Key: conf.SSODefaultPermission, Value: "0", Type: conf.TypeNumber, Group: model.SSO, Flag: model.PRIVATE}, + {Key: conf.SSOCompatibilityMode, Value: "false", Type: conf.TypeBool, Group: model.SSO, Flag: model.PUBLIC}, + + // NOTIFY settings + {Key: conf.NotifyEnabled, Value: "false", Type: conf.TypeBool, Group: model.NOTIFICATION, Flag: model.PUBLIC}, + {Key: conf.NotifyPlatform, Type: conf.TypeSelect, Options: "gotify,goCqHttpBot,serverChan,pushDeer,bark,telegramBot,dingtalkBot,weWorkBot,weWorkApp,aibotk,iGot,pushPlus,chat,email,lark,pushMe,chronocat,webhook,closed", Group: model.NOTIFICATION, Flag: model.PUBLIC}, + {Key: conf.NotifyValue, Type: conf.TypeText, Group: model.NOTIFICATION, Flag: model.PUBLIC}, + {Key: conf.NotifyOnCopySucceeded, Value: "true", Type: conf.TypeBool, Group: model.NOTIFICATION, Flag: model.PUBLIC}, + {Key: conf.NotifyOnCopyFailed, Value: "true", Type: conf.TypeBool, Group: model.NOTIFICATION, Flag: model.PUBLIC}, + {Key: conf.NotifyOnDownloadSucceeded, Value: "true", Type: conf.TypeBool, Group: model.NOTIFICATION, Flag: model.PUBLIC}, + {Key: conf.NotifyOnDownloadFailed, Value: "true", Type: conf.TypeBool, Group: model.NOTIFICATION, Flag: model.PUBLIC}, + + // ldap settings + {Key: conf.LdapLoginEnabled, Value: "false", Type: conf.TypeBool, Group: model.LDAP, Flag: model.PUBLIC}, + {Key: conf.LdapServer, Value: "", Type: conf.TypeString, Group: model.LDAP, Flag: model.PRIVATE}, + {Key: conf.LdapManagerDN, Value: "", Type: conf.TypeString, Group: model.LDAP, Flag: model.PRIVATE}, + {Key: conf.LdapManagerPassword, Value: "", Type: conf.TypeString, Group: model.LDAP, Flag: model.PRIVATE}, + {Key: conf.LdapUserSearchBase, Value: "", Type: conf.TypeString, Group: model.LDAP, Flag: model.PRIVATE}, + {Key: conf.LdapUserSearchFilter, Value: "(uid=%s)", Type: conf.TypeString, Group: model.LDAP, Flag: model.PRIVATE}, + {Key: conf.LdapDefaultDir, Value: "/", Type: conf.TypeString, Group: model.LDAP, Flag: model.PRIVATE}, + {Key: conf.LdapDefaultPermission, Value: "0", Type: conf.TypeNumber, Group: model.LDAP, Flag: model.PRIVATE}, + {Key: conf.LdapLoginTips, Value: "login with ldap", Type: conf.TypeString, Group: model.LDAP, Flag: model.PUBLIC}, + + //s3 settings + {Key: conf.S3AccessKeyId, Value: "", Type: conf.TypeString, Group: model.S3, Flag: model.PRIVATE}, + {Key: conf.S3SecretAccessKey, Value: "", Type: conf.TypeString, Group: model.S3, Flag: model.PRIVATE}, + {Key: conf.S3Buckets, Value: "[]", Type: conf.TypeString, Group: model.S3, Flag: model.PRIVATE}, + } + initialSettingItems = append(initialSettingItems, tool.Tools.Items()...) + if flags.Dev { + initialSettingItems = append(initialSettingItems, []model.SettingItem{ + {Key: "test_deprecated", Value: "test_value", Type: conf.TypeString, Flag: model.DEPRECATED}, + {Key: "test_options", Value: "a", Type: conf.TypeSelect, Options: "a,b,c"}, + {Key: "test_help", Type: conf.TypeString, Help: "this is a help message"}, + }...) 
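+        // the test_* items above exist only in dev mode, to exercise the
+        // DEPRECATED flag, select options, and help-text rendering in the admin UI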
+ } + return initialSettingItems +} diff --git a/internal/bootstrap/data/task.go b/internal/bootstrap/data/task.go new file mode 100644 index 0000000000000000000000000000000000000000..7100e2e25c11c564c870a4ae7ca1cd5ab62c6a76 --- /dev/null +++ b/internal/bootstrap/data/task.go @@ -0,0 +1,29 @@ +package data + +import ( + "github.com/alist-org/alist/v3/internal/db" + "github.com/alist-org/alist/v3/internal/model" +) + +var initialTaskItems []model.TaskItem + +func initTasks() { + InitialTasks() + + for i := range initialTaskItems { + item := &initialTaskItems[i] + taskitem, _ := db.GetTaskDataByType(item.Key) + if taskitem == nil { + db.CreateTaskData(item) + } + } +} + +func InitialTasks() []model.TaskItem { + initialTaskItems = []model.TaskItem{ + {Key: "copy", PersistData: "[]"}, + {Key: "download", PersistData: "[]"}, + {Key: "transfer", PersistData: "[]"}, + } + return initialTaskItems +} diff --git a/internal/bootstrap/data/user.go b/internal/bootstrap/data/user.go new file mode 100644 index 0000000000000000000000000000000000000000..3b71e4982069d4b9d353d8898231a3001b855cbd --- /dev/null +++ b/internal/bootstrap/data/user.go @@ -0,0 +1,101 @@ +package data + +import ( + "os" + + "github.com/alist-org/alist/v3/cmd/flags" + "github.com/alist-org/alist/v3/internal/db" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/internal/op" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/alist-org/alist/v3/pkg/utils/random" + "github.com/pkg/errors" + "gorm.io/gorm" +) + +func initUser() { + admin, err := op.GetAdmin() + adminPassword := random.String(8) + envpass := os.Getenv("ALIST_ADMIN_PASSWORD") + if flags.Dev { + adminPassword = "admin" + } else if len(envpass) > 0 { + adminPassword = envpass + } + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + salt := random.String(16) + admin = &model.User{ + Username: "admin", + Salt: salt, + PwdHash: model.TwoHashPwd(adminPassword, salt), + Role: model.ADMIN, + BasePath: "/", + Authn: "[]", + } + if err := op.CreateUser(admin); err != nil { + panic(err) + } else { + utils.Log.Infof("Successfully created the admin user and the initial password is: %s", adminPassword) + } + } else { + utils.Log.Fatalf("[init user] Failed to get admin user: %v", err) + } + } + guest, err := op.GetGuest() + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + salt := random.String(16) + guest = &model.User{ + Username: "guest", + PwdHash: model.TwoHashPwd("guest", salt), + Salt: salt, + Role: model.GUEST, + BasePath: "/", + Permission: 0, + Disabled: true, + Authn: "[]", + } + if err := db.CreateUser(guest); err != nil { + utils.Log.Fatalf("[init user] Failed to create guest user: %v", err) + } + } else { + utils.Log.Fatalf("[init user] Failed to get guest user: %v", err) + } + } + hashPwdForOldVersion() + updateAuthnForOldVersion() +} + +func hashPwdForOldVersion() { + users, _, err := op.GetUsers(1, -1) + if err != nil { + utils.Log.Fatalf("[hash pwd for old version] failed get users: %v", err) + } + for i := range users { + user := users[i] + if user.PwdHash == "" { + user.SetPassword(user.Password) + user.Password = "" + if err := db.UpdateUser(&user); err != nil { + utils.Log.Fatalf("[hash pwd for old version] failed update user: %v", err) + } + } + } +} + +func updateAuthnForOldVersion() { + users, _, err := op.GetUsers(1, -1) + if err != nil { + utils.Log.Fatalf("[update authn for old version] failed get users: %v", err) + } + for i := range users { + user := users[i] + if user.Authn == "" { + 
user.Authn = "[]" + if err := db.UpdateUser(&user); err != nil { + utils.Log.Fatalf("[update authn for old version] failed update user: %v", err) + } + } + } +} diff --git a/internal/bootstrap/db.go b/internal/bootstrap/db.go new file mode 100644 index 0000000000000000000000000000000000000000..5dfa2820d18c3f7f8a70132367529dc0aa2dd0a9 --- /dev/null +++ b/internal/bootstrap/db.go @@ -0,0 +1,84 @@ +package bootstrap + +import ( + "fmt" + stdlog "log" + "strings" + "time" + + "github.com/alist-org/alist/v3/cmd/flags" + "github.com/alist-org/alist/v3/internal/conf" + "github.com/alist-org/alist/v3/internal/db" + log "github.com/sirupsen/logrus" + "gorm.io/driver/mysql" + "gorm.io/driver/postgres" + "gorm.io/driver/sqlite" + "gorm.io/gorm" + "gorm.io/gorm/logger" + "gorm.io/gorm/schema" +) + +func InitDB() { + logLevel := logger.Silent + if flags.Debug || flags.Dev { + logLevel = logger.Info + } + newLogger := logger.New( + stdlog.New(log.StandardLogger().Out, "\r\n", stdlog.LstdFlags), + logger.Config{ + SlowThreshold: time.Second, + LogLevel: logLevel, + IgnoreRecordNotFoundError: true, + Colorful: true, + }, + ) + gormConfig := &gorm.Config{ + NamingStrategy: schema.NamingStrategy{ + TablePrefix: conf.Conf.Database.TablePrefix, + }, + Logger: newLogger, + } + var dB *gorm.DB + var err error + if flags.Dev { + dB, err = gorm.Open(sqlite.Open("file::memory:?cache=shared"), gormConfig) + conf.Conf.Database.Type = "sqlite3" + } else { + database := conf.Conf.Database + switch database.Type { + case "sqlite3": + { + if !(strings.HasSuffix(database.DBFile, ".db") && len(database.DBFile) > 3) { + log.Fatalf("db name error.") + } + dB, err = gorm.Open(sqlite.Open(fmt.Sprintf("%s?_journal=WAL&_vacuum=incremental", + database.DBFile)), gormConfig) + } + case "mysql": + { + //[username[:password]@][protocol[(address)]]/dbname[?param1=value1&...¶mN=valueN] + dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=True&loc=Local&tls=%s", + database.User, database.Password, database.Host, database.Port, database.Name, database.SSLMode) + if database.DSN != "" { + dsn = database.DSN + } + dB, err = gorm.Open(mysql.Open(dsn), gormConfig) + } + case "postgres": + { + dsn := fmt.Sprintf("host=%s user=%s password=%s dbname=%s port=%d sslmode=%s TimeZone=Asia/Shanghai", + database.Host, database.User, database.Password, database.Name, database.Port, database.SSLMode) + if database.DSN != "" { + dsn = database.DSN + } + dB, err = gorm.Open(postgres.Open(dsn), gormConfig) + } + default: + log.Fatalf("not supported database type: %s", database.Type) + } + } + if err != nil { + log.Fatalf("failed to connect database:%s", err.Error()) + } + db.Init(dB) +} diff --git a/internal/bootstrap/index.go b/internal/bootstrap/index.go new file mode 100644 index 0000000000000000000000000000000000000000..02796c7274becbdbcdaf78dea18d95995842db2e --- /dev/null +++ b/internal/bootstrap/index.go @@ -0,0 +1,18 @@ +package bootstrap + +import ( + "github.com/alist-org/alist/v3/internal/search" + log "github.com/sirupsen/logrus" +) + +func InitIndex() { + progress, err := search.Progress() + if err != nil { + log.Errorf("init index error: %+v", err) + return + } + if !progress.IsDone { + progress.IsDone = true + search.WriteProgress(progress) + } +} diff --git a/internal/bootstrap/log.go b/internal/bootstrap/log.go new file mode 100644 index 0000000000000000000000000000000000000000..00411e5e1890826d5212894e878e6aa3e123a588 --- /dev/null +++ b/internal/bootstrap/log.go @@ -0,0 +1,56 @@ +package bootstrap + +import ( + "io" + 
"log" + "os" + + "github.com/alist-org/alist/v3/cmd/flags" + "github.com/alist-org/alist/v3/internal/conf" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/natefinch/lumberjack" + "github.com/sirupsen/logrus" +) + +func init() { + formatter := logrus.TextFormatter{ + ForceColors: true, + EnvironmentOverrideColors: true, + TimestampFormat: "2006-01-02 15:04:05", + FullTimestamp: true, + } + logrus.SetFormatter(&formatter) + utils.Log.SetFormatter(&formatter) + // logrus.SetLevel(logrus.DebugLevel) +} + +func setLog(l *logrus.Logger) { + if flags.Debug || flags.Dev { + l.SetLevel(logrus.DebugLevel) + l.SetReportCaller(true) + } else { + l.SetLevel(logrus.InfoLevel) + l.SetReportCaller(false) + } +} + +func Log() { + setLog(logrus.StandardLogger()) + setLog(utils.Log) + logConfig := conf.Conf.Log + if logConfig.Enable { + var w io.Writer = &lumberjack.Logger{ + Filename: logConfig.Name, + MaxSize: logConfig.MaxSize, // megabytes + MaxBackups: logConfig.MaxBackups, + MaxAge: logConfig.MaxAge, //days + Compress: logConfig.Compress, // disabled by default + } + if flags.Debug || flags.Dev || flags.LogStd { + w = io.MultiWriter(os.Stdout, w) + } + logrus.SetOutput(w) + } + log.SetOutput(logrus.StandardLogger().Out) + utils.Log.Infof("init logrus...") +} diff --git a/internal/bootstrap/offline_download.go b/internal/bootstrap/offline_download.go new file mode 100644 index 0000000000000000000000000000000000000000..26e04071b1006c4f70d766aed0f8e6dbd6b49b88 --- /dev/null +++ b/internal/bootstrap/offline_download.go @@ -0,0 +1,17 @@ +package bootstrap + +import ( + "github.com/alist-org/alist/v3/internal/offline_download/tool" + "github.com/alist-org/alist/v3/pkg/utils" +) + +func InitOfflineDownloadTools() { + for k, v := range tool.Tools { + res, err := v.Init() + if err != nil { + utils.Log.Warnf("init tool %s failed: %s", k, err) + } else { + utils.Log.Infof("init tool %s success: %s", k, res) + } + } +} diff --git a/internal/bootstrap/storage.go b/internal/bootstrap/storage.go new file mode 100644 index 0000000000000000000000000000000000000000..44af99c56e12e297de1d1222465fa9348db240d5 --- /dev/null +++ b/internal/bootstrap/storage.go @@ -0,0 +1,30 @@ +package bootstrap + +import ( + "context" + + "github.com/alist-org/alist/v3/internal/conf" + "github.com/alist-org/alist/v3/internal/db" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/internal/op" + "github.com/alist-org/alist/v3/pkg/utils" +) + +func LoadStorages() { + storages, err := db.GetEnabledStorages() + if err != nil { + utils.Log.Fatalf("failed get enabled storages: %+v", err) + } + go func(storages []model.Storage) { + for i := range storages { + err := op.LoadStorage(context.Background(), storages[i]) + if err != nil { + utils.Log.Errorf("failed get enabled storages: %+v", err) + } else { + utils.Log.Infof("success load storage: [%s], driver: [%s]", + storages[i].MountPath, storages[i].Driver) + } + } + conf.StoragesLoaded = true + }(storages) +} diff --git a/internal/bootstrap/task.go b/internal/bootstrap/task.go new file mode 100644 index 0000000000000000000000000000000000000000..3672f4daa8ada71fdb2e35cf88b7f39e8979de34 --- /dev/null +++ b/internal/bootstrap/task.go @@ -0,0 +1,64 @@ +package bootstrap + +import ( + "github.com/alist-org/alist/v3/internal/conf" + "github.com/alist-org/alist/v3/internal/db" + "github.com/alist-org/alist/v3/internal/fs" + "github.com/alist-org/alist/v3/internal/offline_download/tool" + "github.com/alist-org/alist/v3/pkg/tache" +) + +func InitTaskManager() { + 
fs.UploadTaskManager = tache.NewManager[*fs.UploadTask](tache.WithWorks(conf.Conf.Tasks.Upload.Workers), tache.WithMaxRetry(conf.Conf.Tasks.Upload.MaxRetry)) // upload tasks are not persisted
+ fs.CopyTaskManager = tache.NewManager[*fs.CopyTask](tache.WithWorks(conf.Conf.Tasks.Copy.Workers), tache.WithPersistFunction(db.GetTaskDataFunc("copy", conf.Conf.Tasks.Copy.TaskPersistant), db.UpdateTaskDataFunc("copy", conf.Conf.Tasks.Copy.TaskPersistant)), tache.WithMaxRetry(conf.Conf.Tasks.Copy.MaxRetry))
+ tool.DownloadTaskManager = tache.NewManager[*tool.DownloadTask](tache.WithWorks(conf.Conf.Tasks.Download.Workers), tache.WithPersistFunction(db.GetTaskDataFunc("download", conf.Conf.Tasks.Download.TaskPersistant), db.UpdateTaskDataFunc("download", conf.Conf.Tasks.Download.TaskPersistant)), tache.WithMaxRetry(conf.Conf.Tasks.Download.MaxRetry))
+ tool.TransferTaskManager = tache.NewManager[*tool.TransferTask](tache.WithWorks(conf.Conf.Tasks.Transfer.Workers), tache.WithPersistFunction(db.GetTaskDataFunc("transfer", conf.Conf.Tasks.Transfer.TaskPersistant), db.UpdateTaskDataFunc("transfer", conf.Conf.Tasks.Transfer.TaskPersistant)), tache.WithMaxRetry(conf.Conf.Tasks.Transfer.MaxRetry))
+ if len(tool.TransferTaskManager.GetAll()) == 0 { // prevent offline downloaded files from being deleted
+ CleanTempDir()
+ }
+}
+
+// func InitTaskManager() {
+
+// uploadTaskPersistPath := conf.Conf.Tasks.Upload.PersistPath
+// copyTaskPersistPath := conf.Conf.Tasks.Copy.PersistPath
+// downloadTaskPersistPath := conf.Conf.Tasks.Download.PersistPath
+// transferTaskPersistPath := conf.Conf.Tasks.Transfer.PersistPath
+// if !utils.Exists(uploadTaskPersistPath) {
+// log.Infof("upload task persist file")
+// _, err := utils.CreateNestedFile(uploadTaskPersistPath)
+// if err != nil {
+// log.Fatalf("failed to create upload task file: %+v", err)
+// }
+// }
+
+// if !utils.Exists(copyTaskPersistPath) {
+// log.Infof("copy task persist file")
+// _, err := utils.CreateNestedFile(copyTaskPersistPath)
+// if err != nil {
+// log.Fatalf("failed to create copy task file: %+v", err)
+// }
+
+// }
+
+// if !utils.Exists(downloadTaskPersistPath) {
+// log.Infof("download task persist file")
+// _, err := utils.CreateNestedFile(downloadTaskPersistPath)
+// if err != nil {
+// log.Fatalf("failed to create download task file: %+v", err)
+// }
+// }
+
+// if !utils.Exists(transferTaskPersistPath) {
+// log.Infof("transfer task persist file")
+// _, err := utils.CreateNestedFile(transferTaskPersistPath)
+// if err != nil {
+// log.Fatalf("failed to create transfer task file: %+v", err)
+// }
+// }
+
+// fs.UploadTaskManager = tache.NewManager[*fs.UploadTask](tache.WithWorks(conf.Conf.Tasks.Upload.Workers), tache.WithPersistPath(uploadTaskPersistPath), tache.WithMaxRetry(conf.Conf.Tasks.Upload.MaxRetry))
+// fs.CopyTaskManager = tache.NewManager[*fs.CopyTask](tache.WithWorks(conf.Conf.Tasks.Copy.Workers), tache.WithPersistPath(copyTaskPersistPath), tache.WithMaxRetry(conf.Conf.Tasks.Copy.MaxRetry))
+// tool.DownloadTaskManager = tache.NewManager[*tool.DownloadTask](tache.WithWorks(conf.Conf.Tasks.Download.Workers), tache.WithPersistPath(downloadTaskPersistPath), tache.WithMaxRetry(conf.Conf.Tasks.Download.MaxRetry))
+// tool.TransferTaskManager = tache.NewManager[*tool.TransferTask](tache.WithWorks(conf.Conf.Tasks.Transfer.Workers), tache.WithPersistPath(transferTaskPersistPath), tache.WithMaxRetry(conf.Conf.Tasks.Transfer.MaxRetry))
+// }
diff --git a/internal/conf/config.go b/internal/conf/config.go
new file mode 100644
index 0000000000000000000000000000000000000000..28385f42195a34feed4cb7f330e0af1f31041b66
--- /dev/null
+++ b/internal/conf/config.go
@@ -0,0 +1,175 @@
+package 
conf + +import ( + "path/filepath" + + "github.com/alist-org/alist/v3/cmd/flags" + "github.com/alist-org/alist/v3/pkg/utils/random" +) + +type Database struct { + Type string `json:"type" env:"TYPE"` + Host string `json:"host" env:"HOST"` + Port int `json:"port" env:"PORT"` + User string `json:"user" env:"USER"` + Password string `json:"password" env:"PASS"` + Name string `json:"name" env:"NAME"` + DBFile string `json:"db_file" env:"FILE"` + TablePrefix string `json:"table_prefix" env:"TABLE_PREFIX"` + SSLMode string `json:"ssl_mode" env:"SSL_MODE"` + DSN string `json:"dsn" env:"DSN"` +} + +type Meilisearch struct { + Host string `json:"host" env:"HOST"` + APIKey string `json:"api_key" env:"API_KEY"` + IndexPrefix string `json:"index_prefix" env:"INDEX_PREFIX"` +} + +type Scheme struct { + Address string `json:"address" env:"ADDR"` + HttpPort int `json:"http_port" env:"HTTP_PORT"` + HttpsPort int `json:"https_port" env:"HTTPS_PORT"` + ForceHttps bool `json:"force_https" env:"FORCE_HTTPS"` + CertFile string `json:"cert_file" env:"CERT_FILE"` + KeyFile string `json:"key_file" env:"KEY_FILE"` + UnixFile string `json:"unix_file" env:"UNIX_FILE"` + UnixFilePerm string `json:"unix_file_perm" env:"UNIX_FILE_PERM"` +} + +type LogConfig struct { + Enable bool `json:"enable" env:"LOG_ENABLE"` + Name string `json:"name" env:"LOG_NAME"` + MaxSize int `json:"max_size" env:"MAX_SIZE"` + MaxBackups int `json:"max_backups" env:"MAX_BACKUPS"` + MaxAge int `json:"max_age" env:"MAX_AGE"` + Compress bool `json:"compress" env:"COMPRESS"` +} + +type TaskConfig struct { + Workers int `json:"workers" env:"WORKERS"` + MaxRetry int `json:"max_retry" env:"MAX_RETRY"` + PersistPath string `json:"persist_path" env:"PERSISTPATH"` + TaskPersistant bool `json:"task_persistant" env:"TASK_PERSISTANT"` +} + +type TasksConfig struct { + Download TaskConfig `json:"download" envPrefix:"DOWNLOAD_"` + Transfer TaskConfig `json:"transfer" envPrefix:"TRANSFER_"` + Upload TaskConfig `json:"upload" envPrefix:"UPLOAD_"` + Copy TaskConfig `json:"copy" envPrefix:"COPY_"` +} + +type Cors struct { + AllowOrigins []string `json:"allow_origins" env:"ALLOW_ORIGINS"` + AllowMethods []string `json:"allow_methods" env:"ALLOW_METHODS"` + AllowHeaders []string `json:"allow_headers" env:"ALLOW_HEADERS"` +} + +type S3 struct { + Enable bool `json:"enable" env:"ENABLE"` + Port int `json:"port" env:"PORT"` + SSL bool `json:"ssl" env:"SSL"` +} + +type Config struct { + Force bool `json:"force" env:"FORCE"` + Notify bool `json:"notify" env:"NOTIFY"` + SiteURL string `json:"site_url" env:"SITE_URL"` + Cdn string `json:"cdn" env:"CDN"` + JwtSecret string `json:"jwt_secret" env:"JWT_SECRET"` + TokenExpiresIn int `json:"token_expires_in" env:"TOKEN_EXPIRES_IN"` + Database Database `json:"database" envPrefix:"DB_"` + Meilisearch Meilisearch `json:"meilisearch" envPrefix:"MEILISEARCH_"` + Scheme Scheme `json:"scheme"` + TempDir string `json:"temp_dir" env:"TEMP_DIR"` + BleveDir string `json:"bleve_dir" env:"BLEVE_DIR"` + DistDir string `json:"dist_dir"` + Log LogConfig `json:"log"` + DelayedStart int `json:"delayed_start" env:"DELAYED_START"` + MaxConnections int `json:"max_connections" env:"MAX_CONNECTIONS"` + TlsInsecureSkipVerify bool `json:"tls_insecure_skip_verify" env:"TLS_INSECURE_SKIP_VERIFY"` + Tasks TasksConfig `json:"tasks" envPrefix:"TASKS_"` + Cors Cors `json:"cors" envPrefix:"CORS_"` + S3 S3 `json:"s3" envPrefix:"S3_"` +} + +func DefaultConfig() *Config { + tempDir := filepath.Join(flags.DataDir, "temp") + indexDir := 
filepath.Join(flags.DataDir, "bleve") + logPath := filepath.Join(flags.DataDir, "log/log.log") + dbPath := filepath.Join(flags.DataDir, "data.db") + downloadPersistPath := filepath.Join(flags.DataDir, "tasks/download.json") + transferPersistPath := filepath.Join(flags.DataDir, "tasks/transfer.json") + uploadPersistPath := filepath.Join(flags.DataDir, "tasks/upload.json") + copyPersistPath := filepath.Join(flags.DataDir, "tasks/copy.json") + return &Config{ + Scheme: Scheme{ + Address: "0.0.0.0", + UnixFile: "", + HttpPort: 5244, + HttpsPort: -1, + ForceHttps: false, + CertFile: "", + KeyFile: "", + }, + Notify: true, + JwtSecret: random.String(16), + TokenExpiresIn: 48, + TempDir: tempDir, + Database: Database{ + Type: "sqlite3", + Port: 0, + TablePrefix: "x_", + DBFile: dbPath, + }, + Meilisearch: Meilisearch{ + Host: "http://localhost:7700", + }, + BleveDir: indexDir, + Log: LogConfig{ + Enable: true, + Name: logPath, + MaxSize: 50, + MaxBackups: 30, + MaxAge: 28, + }, + MaxConnections: 0, + TlsInsecureSkipVerify: true, + Tasks: TasksConfig{ + Download: TaskConfig{ + Workers: 5, + MaxRetry: 1, + PersistPath: downloadPersistPath, + TaskPersistant: true, + }, + Transfer: TaskConfig{ + Workers: 5, + MaxRetry: 2, + PersistPath: transferPersistPath, + TaskPersistant: true, + }, + Upload: TaskConfig{ + Workers: 5, + PersistPath: uploadPersistPath, + TaskPersistant: true, + }, + Copy: TaskConfig{ + Workers: 5, + MaxRetry: 2, + PersistPath: copyPersistPath, + TaskPersistant: true, + }, + }, + Cors: Cors{ + AllowOrigins: []string{"*"}, + AllowMethods: []string{"*"}, + AllowHeaders: []string{"*"}, + }, + S3: S3{ + Enable: false, + Port: 5246, + SSL: false, + }, + } +} diff --git a/internal/conf/const.go b/internal/conf/const.go new file mode 100644 index 0000000000000000000000000000000000000000..dbf77f3059774a408ab346d4abbf54efedd63f92 --- /dev/null +++ b/internal/conf/const.go @@ -0,0 +1,120 @@ +package conf + +const ( + TypeString = "string" + TypeSelect = "select" + TypeBool = "bool" + TypeText = "text" + TypeNumber = "number" +) + +const ( + // site + VERSION = "version" + SiteTitle = "site_title" + Announcement = "announcement" + AllowIndexed = "allow_indexed" + AllowMounted = "allow_mounted" + RobotsTxt = "robots_txt" + + Logo = "logo" + Favicon = "favicon" + MainColor = "main_color" + + // preview + TextTypes = "text_types" + AudioTypes = "audio_types" + VideoTypes = "video_types" + ImageTypes = "image_types" + ProxyTypes = "proxy_types" + ProxyIgnoreHeaders = "proxy_ignore_headers" + AudioAutoplay = "audio_autoplay" + VideoAutoplay = "video_autoplay" + + // global + HideFiles = "hide_files" + CustomizeHead = "customize_head" + CustomizeBody = "customize_body" + LinkExpiration = "link_expiration" + SignAll = "sign_all" + PrivacyRegs = "privacy_regs" + OcrApi = "ocr_api" + FilenameCharMapping = "filename_char_mapping" + ForwardDirectLinkParams = "forward_direct_link_params" + IgnoreDirectLinkParams = "ignore_direct_link_params" + StorageGroups = "storage_groups" + WebauthnLoginEnabled = "webauthn_login_enabled" + + // index + SearchIndex = "search_index" + AutoUpdateIndex = "auto_update_index" + IgnorePaths = "ignore_paths" + MaxIndexDepth = "max_index_depth" + + // aria2 + Aria2Uri = "aria2_uri" + Aria2Secret = "aria2_secret" + + // single + Token = "token" + IndexProgress = "index_progress" + + //SSO + SSOClientId = "sso_client_id" + SSOClientSecret = "sso_client_secret" + SSOLoginEnabled = "sso_login_enabled" + SSOLoginPlatform = "sso_login_platform" + SSOOIDCUsernameKey = 
"sso_oidc_username_key" + SSOOrganizationName = "sso_organization_name" + SSOApplicationName = "sso_application_name" + SSOEndpointName = "sso_endpoint_name" + SSOJwtPublicKey = "sso_jwt_public_key" + SSOAutoRegister = "sso_auto_register" + SSODefaultDir = "sso_default_dir" + SSODefaultPermission = "sso_default_permission" + SSOCompatibilityMode = "sso_compatibility_mode" + + //NOTYFY + NotifyEnabled = "notify_enabled" + NotifyPlatform = "notify_platform" + NotifyValue = "notify_value" + NotifyOnCopySucceeded = "notify_on_copy_succeeded" + NotifyOnCopyFailed = "notify_on_copy_failed" + NotifyOnDownloadSucceeded = "notify_on_download_succeeded" + NotifyOnDownloadFailed = "notify_on_download_failed" + + //ldap + LdapLoginEnabled = "ldap_login_enabled" + LdapServer = "ldap_server" + LdapManagerDN = "ldap_manager_dn" + LdapManagerPassword = "ldap_manager_password" + LdapUserSearchBase = "ldap_user_search_base" + LdapUserSearchFilter = "ldap_user_search_filter" + LdapDefaultPermission = "ldap_default_permission" + LdapDefaultDir = "ldap_default_dir" + LdapLoginTips = "ldap_login_tips" + + //s3 + S3Buckets = "s3_buckets" + S3AccessKeyId = "s3_access_key_id" + S3SecretAccessKey = "s3_secret_access_key" + + // qbittorrent + QbittorrentUrl = "qbittorrent_url" + QbittorrentSeedtime = "qbittorrent_seedtime" +) + +const ( + UNKNOWN = iota + FOLDER + //OFFICE + VIDEO + AUDIO + TEXT + IMAGE +) + +// ContextKey is the type of context keys. +const ( + NoTaskKey = "no_task" +) diff --git a/internal/conf/var.go b/internal/conf/var.go new file mode 100644 index 0000000000000000000000000000000000000000..0a8eb16fcd15a9918951ee0728d108c892479a88 --- /dev/null +++ b/internal/conf/var.go @@ -0,0 +1,34 @@ +package conf + +import ( + "net/url" + "regexp" +) + +var ( + BuiltAt string + GoVersion string + GitAuthor string + GitCommit string + Version string = "dev" + WebVersion string +) + +var ( + Conf *Config + URL *url.URL +) + +var SlicesMap = make(map[string][]string) +var FilenameCharMap = make(map[string]string) +var PrivacyReg []*regexp.Regexp + +var ( + // StoragesLoaded loaded success if empty + StoragesLoaded = false +) +var ( + RawIndexHtml string + ManageHtml string + IndexHtml string +) diff --git a/internal/db/db.go b/internal/db/db.go new file mode 100644 index 0000000000000000000000000000000000000000..2df58d3760b017ab056977b0604fe8354716485f --- /dev/null +++ b/internal/db/db.go @@ -0,0 +1,47 @@ +package db + +import ( + log "github.com/sirupsen/logrus" + + "github.com/alist-org/alist/v3/internal/conf" + "github.com/alist-org/alist/v3/internal/model" + "gorm.io/gorm" +) + +var db *gorm.DB + +func Init(d *gorm.DB) { + db = d + err := AutoMigrate(new(model.Storage), new(model.User), new(model.Meta), new(model.SettingItem), new(model.SearchNode), new(model.TaskItem)) + if err != nil { + log.Fatalf("failed migrate database: %s", err.Error()) + } +} + +func AutoMigrate(dst ...interface{}) error { + var err error + if conf.Conf.Database.Type == "mysql" { + err = db.Set("gorm:table_options", "ENGINE=InnoDB CHARSET=utf8mb4").AutoMigrate(dst...) + } else { + err = db.AutoMigrate(dst...) 
+ } + return err +} + +func GetDb() *gorm.DB { + return db +} + +func Close() { + log.Info("closing db") + sqlDB, err := db.DB() + if err != nil { + log.Errorf("failed to get db: %s", err.Error()) + return + } + err = sqlDB.Close() + if err != nil { + log.Errorf("failed to close db: %s", err.Error()) + return + } +} diff --git a/internal/db/meta.go b/internal/db/meta.go new file mode 100644 index 0000000000000000000000000000000000000000..8b6a605e8098e8a04b9b76cf644b2064f966c647 --- /dev/null +++ b/internal/db/meta.go @@ -0,0 +1,45 @@ +package db + +import ( + "github.com/alist-org/alist/v3/internal/model" + "github.com/pkg/errors" +) + +func GetMetaByPath(path string) (*model.Meta, error) { + meta := model.Meta{Path: path} + if err := db.Where(meta).First(&meta).Error; err != nil { + return nil, errors.Wrapf(err, "failed select meta") + } + return &meta, nil +} + +func GetMetaById(id uint) (*model.Meta, error) { + var u model.Meta + if err := db.First(&u, id).Error; err != nil { + return nil, errors.Wrapf(err, "failed get old meta") + } + return &u, nil +} + +func CreateMeta(u *model.Meta) error { + return errors.WithStack(db.Create(u).Error) +} + +func UpdateMeta(u *model.Meta) error { + return errors.WithStack(db.Save(u).Error) +} + +func GetMetas(pageIndex, pageSize int) (metas []model.Meta, count int64, err error) { + metaDB := db.Model(&model.Meta{}) + if err = metaDB.Count(&count).Error; err != nil { + return nil, 0, errors.Wrapf(err, "failed get metas count") + } + if err = metaDB.Offset((pageIndex - 1) * pageSize).Limit(pageSize).Find(&metas).Error; err != nil { + return nil, 0, errors.Wrapf(err, "failed get find metas") + } + return metas, count, nil +} + +func DeleteMetaById(id uint) error { + return errors.WithStack(db.Delete(&model.Meta{}, id).Error) +} diff --git a/internal/db/searchnode.go b/internal/db/searchnode.go new file mode 100644 index 0000000000000000000000000000000000000000..03251d4c50767a40ef2e927f0da8587257c0933e --- /dev/null +++ b/internal/db/searchnode.go @@ -0,0 +1,91 @@ +package db + +import ( + "fmt" + stdpath "path" + "strings" + + "github.com/alist-org/alist/v3/internal/conf" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/pkg/errors" + "gorm.io/gorm" +) + +func whereInParent(parent string) *gorm.DB { + if parent == "/" { + return db.Where("1 = 1") + } + return db.Where(fmt.Sprintf("%s LIKE ?", columnName("parent")), + fmt.Sprintf("%s/%%", parent)). + Or(fmt.Sprintf("%s = ?", columnName("parent")), parent) +} + +func CreateSearchNode(node *model.SearchNode) error { + return db.Create(node).Error +} + +func BatchCreateSearchNodes(nodes *[]model.SearchNode) error { + return db.CreateInBatches(nodes, 1000).Error +} + +func DeleteSearchNodesByParent(path string) error { + path = utils.FixAndCleanPath(path) + err := db.Where(whereInParent(path)).Delete(&model.SearchNode{}).Error + if err != nil { + return err + } + dir, name := stdpath.Split(path) + return db.Where(fmt.Sprintf("%s = ? 
AND %s = ?", + columnName("parent"), columnName("name")), + dir, name).Delete(&model.SearchNode{}).Error +} + +func ClearSearchNodes() error { + return db.Where("1 = 1").Delete(&model.SearchNode{}).Error +} + +func GetSearchNodesByParent(parent string) ([]model.SearchNode, error) { + var nodes []model.SearchNode + if err := db.Where(fmt.Sprintf("%s = ?", + columnName("parent")), parent).Find(&nodes).Error; err != nil { + return nil, err + } + return nodes, nil +} + +func SearchNode(req model.SearchReq, useFullText bool) ([]model.SearchNode, int64, error) { + var searchDB *gorm.DB + if !useFullText || conf.Conf.Database.Type == "sqlite3" { + keywordsClause := db.Where("1 = 1") + for _, keyword := range strings.Fields(req.Keywords) { + keywordsClause = keywordsClause.Where("name LIKE ?", fmt.Sprintf("%%%s%%", keyword)) + } + searchDB = db.Model(&model.SearchNode{}).Where(whereInParent(req.Parent)).Where(keywordsClause) + } else { + switch conf.Conf.Database.Type { + case "mysql": + searchDB = db.Model(&model.SearchNode{}).Where(whereInParent(req.Parent)). + Where("MATCH (name) AGAINST (? IN BOOLEAN MODE)", "'*"+req.Keywords+"*'") + case "postgres": + searchDB = db.Model(&model.SearchNode{}).Where(whereInParent(req.Parent)). + Where("to_tsvector(name) @@ to_tsquery(?)", strings.Join(strings.Fields(req.Keywords), " & ")) + } + } + + if req.Scope != 0 { + isDir := req.Scope == 1 + searchDB.Where(db.Where("is_dir = ?", isDir)) + } + + var count int64 + if err := searchDB.Count(&count).Error; err != nil { + return nil, 0, errors.Wrapf(err, "failed get search items count") + } + var files []model.SearchNode + if err := searchDB.Order("name asc").Offset((req.Page - 1) * req.PerPage).Limit(req.PerPage). + Find(&files).Error; err != nil { + return nil, 0, err + } + return files, count, nil +} diff --git a/internal/db/settingitem.go b/internal/db/settingitem.go new file mode 100644 index 0000000000000000000000000000000000000000..2ba0c665acd7c5cc39e0168e4ec54c03a8c4a9d3 --- /dev/null +++ b/internal/db/settingitem.go @@ -0,0 +1,68 @@ +package db + +import ( + "fmt" + + "github.com/alist-org/alist/v3/internal/model" + "github.com/pkg/errors" +) + +func GetSettingItems() ([]model.SettingItem, error) { + var settingItems []model.SettingItem + if err := db.Find(&settingItems).Error; err != nil { + return nil, errors.WithStack(err) + } + return settingItems, nil +} + +func GetSettingItemByKey(key string) (*model.SettingItem, error) { + var settingItem model.SettingItem + if err := db.Where(fmt.Sprintf("%s = ?", columnName("key")), key).First(&settingItem).Error; err != nil { + return nil, errors.WithStack(err) + } + return &settingItem, nil +} + +// func GetSettingItemInKeys(keys []string) ([]model.SettingItem, error) { +// var settingItem []model.SettingItem +// if err := db.Where(fmt.Sprintf("%s in ?", columnName("key")), keys).Find(&settingItem).Error; err != nil { +// return nil, errors.WithStack(err) +// } +// return settingItem, nil +// } + +func GetPublicSettingItems() ([]model.SettingItem, error) { + var settingItems []model.SettingItem + if err := db.Where(fmt.Sprintf("%s in ?", columnName("flag")), []int{model.PUBLIC, model.READONLY}).Find(&settingItems).Error; err != nil { + return nil, errors.WithStack(err) + } + return settingItems, nil +} + +func GetSettingItemsByGroup(group int) ([]model.SettingItem, error) { + var settingItems []model.SettingItem + if err := db.Where(fmt.Sprintf("%s = ?", columnName("group")), group).Find(&settingItems).Error; err != nil { + return nil, errors.WithStack(err) 
+ } + return settingItems, nil +} + +func GetSettingItemsInGroups(groups []int) ([]model.SettingItem, error) { + var settingItems []model.SettingItem + if err := db.Where(fmt.Sprintf("%s in ?", columnName("group")), groups).Find(&settingItems).Error; err != nil { + return nil, errors.WithStack(err) + } + return settingItems, nil +} + +func SaveSettingItems(items []model.SettingItem) (err error) { + return errors.WithStack(db.Save(items).Error) +} + +func SaveSettingItem(item *model.SettingItem) error { + return errors.WithStack(db.Save(item).Error) +} + +func DeleteSettingItemByKey(key string) error { + return errors.WithStack(db.Delete(&model.SettingItem{Key: key}).Error) +} diff --git a/internal/db/storage.go b/internal/db/storage.go new file mode 100644 index 0000000000000000000000000000000000000000..1639929117d41f5729815664e0f9d36d51118343 --- /dev/null +++ b/internal/db/storage.go @@ -0,0 +1,127 @@ +package db + +import ( + "fmt" + + "github.com/alist-org/alist/v3/internal/model" + "github.com/pkg/errors" + "gorm.io/gorm" +) + +// why don't need `cache` for storage? +// because all storage store in `op.storagesMap` +// the most of the read operation is from `op.storagesMap` +// just for persistence in database + +// CreateStorage just insert storage to database +func CreateStorage(storage *model.Storage) error { + return errors.WithStack(db.Create(storage).Error) +} + +// UpdateStorage just update storage in database +func UpdateStorage(storage *model.Storage) error { + return errors.WithStack(db.Save(storage).Error) +} + +// DeleteStorageById just delete storage from database by id +func DeleteStorageById(id uint) error { + return errors.WithStack(db.Delete(&model.Storage{}, id).Error) +} + +// GetStorages Get all storages from database order by index +func GetStorages(pageIndex, pageSize int) ([]model.Storage, int64, error) { + storageDB := db.Model(&model.Storage{}) + var count int64 + if err := storageDB.Count(&count).Error; err != nil { + return nil, 0, errors.Wrapf(err, "failed get storages count") + } + var storages []model.Storage + if err := storageDB.Order(columnName("order")).Offset((pageIndex - 1) * pageSize).Limit(pageSize).Find(&storages).Error; err != nil { + return nil, 0, errors.WithStack(err) + } + return storages, count, nil +} + +// GetStorageById Get Storage by id, used to update storage usually +func GetStorageById(id uint) (*model.Storage, error) { + var storage model.Storage + storage.ID = id + if err := db.First(&storage).Error; err != nil { + return nil, errors.WithStack(err) + } + return &storage, nil +} + +// GetStorageByMountPath Get Storage by mountPath, used to update storage usually +func GetStorageByMountPath(mountPath string) (*model.Storage, error) { + var storage model.Storage + if err := db.Where("mount_path = ?", mountPath).First(&storage).Error; err != nil { + return nil, errors.WithStack(err) + } + return &storage, nil +} + +func GetEnabledStorages() ([]model.Storage, error) { + var storages []model.Storage + if err := db.Where(fmt.Sprintf("%s = ?", columnName("disabled")), false).Find(&storages).Error; err != nil { + return nil, errors.WithStack(err) + } + return storages, nil +} + +func GetGroupStorages(groupName string) ([]model.Storage, error) { + var storages []model.Storage + if err := db.Where(fmt.Sprintf("%s = ?", columnName("group")), groupName).Find(&storages).Error; err != nil { + return nil, errors.WithStack(err) + } + return storages, nil +} + +func UpdateGroupStorages(groupName string, changedAdditions map[string]interface{}) error 
{
+ var storages []model.Storage
+ if err := db.Where(fmt.Sprintf("%s = ?", columnName("group")), groupName).Find(&storages).Error; err != nil {
+ return errors.WithStack(err)
+ }
+ // dynamically build the SQL expression
+ ids := extractField(storages, func(u model.Storage) int { return int(u.ID) }) // collect the ids of all storages in the same group
+
+ // Option 1: update with field-name -> new-value pairs
+ expr := "addition"
+ var args []interface{}
+ for key, val := range changedAdditions {
+ expr = fmt.Sprintf("JSON_SET(%s, ?, ?)", expr)
+ args = append(args, "$."+key, val)
+ }
+ updates := map[string]interface{}{
+ "addition": gorm.Expr(expr, args...),
+ }
+ // run the update
+ if updateErr := db.Model(&model.Storage{}).Where("id IN ?", ids).Updates(updates).Error; updateErr != nil {
+ return errors.WithStack(updateErr)
+ }
+
+ // Option 2: update with old-value -> new-value string replacement
+ // var keys []string
+ // for oldStr := range changedAdditions {
+ // keys = append(keys, oldStr)
+ // }
+ // sort.Strings(keys) // sort keys lexicographically
+ // // iterate over the sorted keys
+ // expr := "addition"
+ // for _, oldStr := range keys {
+ // newStr := changedAdditions[oldStr]
+ // expr = fmt.Sprintf("REPLACE(%s, '%s', '%s')", expr, oldStr, newStr)
+ // }
+ // if updateErr := db.Model(&model.Storage{}).Where("id IN ?", ids).Update("addition", gorm.Expr(expr)).Error; updateErr != nil {
+ // return errors.WithStack(updateErr)
+ // }
+ return nil
+}
+
+func extractField[T any, F any](slice []T, getter func(T) F) []F {
+ result := make([]F, 0, len(slice))
+ for _, item := range slice {
+ result = append(result, getter(item))
+ }
+ return result
+}
diff --git a/internal/db/tasks.go b/internal/db/tasks.go
new file mode 100644
index 0000000000000000000000000000000000000000..9d2de1cff64ea8b1fc367e002833953461e3bcf2
--- /dev/null
+++ b/internal/db/tasks.go
@@ -0,0 +1,48 @@
+package db
+
+import (
+ "github.com/alist-org/alist/v3/internal/model"
+ "github.com/pkg/errors"
+)
+
+func GetTaskDataByType(type_s string) (*model.TaskItem, error) {
+ task := model.TaskItem{Key: type_s}
+ if err := db.Where(task).First(&task).Error; err != nil {
+ return nil, errors.Wrapf(err, "failed find task")
+ }
+ return &task, nil
+}
+
+func UpdateTaskData(t *model.TaskItem) error {
+ return errors.WithStack(db.Model(&model.TaskItem{}).Where("key = ?", t.Key).Update("persist_data", t.PersistData).Error)
+}
+
+func CreateTaskData(t *model.TaskItem) error {
+ return errors.WithStack(db.Create(t).Error)
+}
+
+func GetTaskDataFunc(type_s string, enabled bool) func() ([]byte, error) {
+ if !enabled {
+ return nil
+ }
+ task, err := GetTaskDataByType(type_s)
+ if err != nil {
+ return nil
+ }
+ return func() ([]byte, error) {
+ return []byte(task.PersistData), nil
+ }
+}
+
+func UpdateTaskDataFunc(type_s string, enabled bool) func([]byte) error {
+ if !enabled {
+ return nil
+ }
+ return func(data []byte) error {
+ s := string(data)
+ if s == "null" || s == "" {
+ s = "[]"
+ }
+ return UpdateTaskData(&model.TaskItem{Key: type_s, PersistData: s})
+ }
+}
diff --git a/internal/db/user.go b/internal/db/user.go
new file mode 100644
index 0000000000000000000000000000000000000000..822926664c91fed8071a999760491902a7145a3e
--- /dev/null
+++ b/internal/db/user.go
@@ -0,0 +1,102 @@
+package db
+
+import (
+ "encoding/base64"
+
+ "github.com/alist-org/alist/v3/internal/model"
+ "github.com/alist-org/alist/v3/pkg/utils"
+ "github.com/go-webauthn/webauthn/webauthn"
+ "github.com/pkg/errors"
+)
+
+func GetUserByRole(role int) (*model.User, error) {
+ user := model.User{Role: role}
+ if err := db.Where(user).Take(&user).Error; err != nil {
+ return nil, err
+ }
+ return &user, nil
+}
+
+func 
GetUserByName(username string) (*model.User, error) { + user := model.User{Username: username} + if err := db.Where(user).First(&user).Error; err != nil { + return nil, errors.Wrapf(err, "failed find user") + } + return &user, nil +} + +func GetUserBySSOID(ssoID string) (*model.User, error) { + user := model.User{SsoID: ssoID} + if err := db.Where(user).First(&user).Error; err != nil { + return nil, errors.Wrapf(err, "The single sign on platform is not bound to any users") + } + return &user, nil +} + +func GetUserById(id uint) (*model.User, error) { + var u model.User + if err := db.First(&u, id).Error; err != nil { + return nil, errors.Wrapf(err, "failed get old user") + } + return &u, nil +} + +func CreateUser(u *model.User) error { + return errors.WithStack(db.Create(u).Error) +} + +func UpdateUser(u *model.User) error { + return errors.WithStack(db.Save(u).Error) +} + +func GetUsers(pageIndex, pageSize int) (users []model.User, count int64, err error) { + userDB := db.Model(&model.User{}) + if err := userDB.Count(&count).Error; err != nil { + return nil, 0, errors.Wrapf(err, "failed get users count") + } + if err := userDB.Offset((pageIndex - 1) * pageSize).Limit(pageSize).Find(&users).Error; err != nil { + return nil, 0, errors.Wrapf(err, "failed get find users") + } + return users, count, nil +} + +func DeleteUserById(id uint) error { + return errors.WithStack(db.Delete(&model.User{}, id).Error) +} + +func UpdateAuthn(userID uint, authn string) error { + return db.Model(&model.User{ID: userID}).Update("authn", authn).Error +} + +func RegisterAuthn(u *model.User, credential *webauthn.Credential) error { + if u == nil { + return errors.New("user is nil") + } + exists := u.WebAuthnCredentials() + if credential != nil { + exists = append(exists, *credential) + } + res, err := utils.Json.Marshal(exists) + if err != nil { + return err + } + return UpdateAuthn(u.ID, string(res)) +} + +func RemoveAuthn(u *model.User, id string) error { + exists := u.WebAuthnCredentials() + for i := 0; i < len(exists); i++ { + idEncoded := base64.StdEncoding.EncodeToString(exists[i].ID) + if idEncoded == id { + exists[len(exists)-1], exists[i] = exists[i], exists[len(exists)-1] + exists = exists[:len(exists)-1] + break + } + } + + res, err := utils.Json.Marshal(exists) + if err != nil { + return err + } + return UpdateAuthn(u.ID, string(res)) +} diff --git a/internal/db/util.go b/internal/db/util.go new file mode 100644 index 0000000000000000000000000000000000000000..38a06bcdacc43165cab085611f4c8f95d8f386bc --- /dev/null +++ b/internal/db/util.go @@ -0,0 +1,14 @@ +package db + +import ( + "fmt" + + "github.com/alist-org/alist/v3/internal/conf" +) + +func columnName(name string) string { + if conf.Conf.Database.Type == "postgres" { + return fmt.Sprintf(`"%s"`, name) + } + return fmt.Sprintf("`%s`", name) +} diff --git a/internal/driver/config.go b/internal/driver/config.go new file mode 100644 index 0000000000000000000000000000000000000000..6068143cb718ee94bcb097f588a18db0c8955c61 --- /dev/null +++ b/internal/driver/config.go @@ -0,0 +1,20 @@ +package driver + +type Config struct { + Name string `json:"name"` + LocalSort bool `json:"local_sort"` + OnlyLocal bool `json:"only_local"` + OnlyProxy bool `json:"only_proxy"` + NoCache bool `json:"no_cache"` + NoUpload bool `json:"no_upload"` + NeedMs bool `json:"need_ms"` // if need get message from user, such as validate code + DefaultRoot string `json:"default_root"` + CheckStatus bool `json:"-"` + Alert string `json:"alert"` //info,success,warning,danger + 
NoOverwriteUpload bool `json:"-"` // whether to support overwrite upload + ProxyRangeOption bool `json:"-"` +} + +func (c Config) MustProxy() bool { + return c.OnlyProxy || c.OnlyLocal +} diff --git a/internal/driver/driver.go b/internal/driver/driver.go new file mode 100644 index 0000000000000000000000000000000000000000..a3bf37860f7b2429ff24ee998afbcf9caf6b9fe8 --- /dev/null +++ b/internal/driver/driver.go @@ -0,0 +1,136 @@ +package driver + +import ( + "context" + + "github.com/alist-org/alist/v3/internal/model" +) + +type Driver interface { + Meta + Reader + //Writer + //Other +} + +type Meta interface { + Config() Config + // GetStorage just get raw storage, no need to implement, because model.Storage have implemented + GetStorage() *model.Storage + SetStorage(model.Storage) + // GetAddition Additional is used for unmarshal of JSON, so need return pointer + GetAddition() Additional + // Init If already initialized, drop first + Init(ctx context.Context) error + Drop(ctx context.Context) error +} + +type Other interface { + Other(ctx context.Context, args model.OtherArgs) (interface{}, error) +} + +type Offline interface { + Offline(ctx context.Context, args model.OtherArgs) (interface{}, error) +} + +type Reader interface { + // List files in the path + // if identify files by path, need to set ID with path,like path.Join(dir.GetID(), obj.GetName()) + // if identify files by id, need to set ID with corresponding id + List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) + // Link get url/filepath/reader of file + Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) +} + +type GetRooter interface { + GetRoot(ctx context.Context) (model.Obj, error) +} + +type Getter interface { + // Get file by path, the path haven't been joined with root path + Get(ctx context.Context, path string) (model.Obj, error) +} + +//type Writer interface { +// Mkdir +// Move +// Rename +// Copy +// Remove +// Put +//} + +type Mkdir interface { + MakeDir(ctx context.Context, parentDir model.Obj, dirName string) error +} + +type Move interface { + Move(ctx context.Context, srcObj, dstDir model.Obj) error +} + +type Rename interface { + Rename(ctx context.Context, srcObj model.Obj, newName string) error +} + +type Copy interface { + Copy(ctx context.Context, srcObj, dstDir model.Obj) error +} + +type Remove interface { + Remove(ctx context.Context, obj model.Obj) error +} + +type Put interface { + Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up UpdateProgress) error +} + +//type WriteResult interface { +// MkdirResult +// MoveResult +// RenameResult +// CopyResult +// PutResult +// Remove +//} + +type MkdirResult interface { + MakeDir(ctx context.Context, parentDir model.Obj, dirName string) (model.Obj, error) +} + +type MoveResult interface { + Move(ctx context.Context, srcObj, dstDir model.Obj) (model.Obj, error) +} + +type RenameResult interface { + Rename(ctx context.Context, srcObj model.Obj, newName string) (model.Obj, error) +} + +type CopyResult interface { + Copy(ctx context.Context, srcObj, dstDir model.Obj) (model.Obj, error) +} + +type PutResult interface { + Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up UpdateProgress) (model.Obj, error) +} + +type UpdateProgress func(percentage float64) + +type Progress struct { + Total int64 + Done int64 + up UpdateProgress +} + +func (p *Progress) Write(b []byte) (n int, err error) { + n = len(b) + p.Done += int64(n) + p.up(float64(p.Done) / 
float64(p.Total) * 100) + return +} + +func NewProgress(total int64, up UpdateProgress) *Progress { + return &Progress{ + Total: total, + up: up, + } +} diff --git a/internal/driver/item.go b/internal/driver/item.go new file mode 100644 index 0000000000000000000000000000000000000000..e8b0c8bf4cb1e95897fce2b59dbff2c5782999d5 --- /dev/null +++ b/internal/driver/item.go @@ -0,0 +1,48 @@ +package driver + +type Additional interface{} + +type Select string + +type Item struct { + Name string `json:"name"` + Type string `json:"type"` + Default string `json:"default"` + Options string `json:"options"` + Required bool `json:"required"` + Help string `json:"help"` +} + +type Info struct { + Common []Item `json:"common"` + Additional []Item `json:"additional"` + Config Config `json:"config"` +} + +type IRootPath interface { + GetRootPath() string +} + +type IRootId interface { + GetRootId() string +} + +type RootPath struct { + RootFolderPath string `json:"root_folder_path"` +} + +type RootID struct { + RootFolderID string `json:"root_folder_id"` +} + +func (r RootPath) GetRootPath() string { + return r.RootFolderPath +} + +func (r *RootPath) SetRootPath(path string) { + r.RootFolderPath = path +} + +func (r RootID) GetRootId() string { + return r.RootFolderID +} diff --git a/internal/errs/driver.go b/internal/errs/driver.go new file mode 100644 index 0000000000000000000000000000000000000000..4b6b5cac48e22e10e7482ab7cdaff6aa3e5c1c1b --- /dev/null +++ b/internal/errs/driver.go @@ -0,0 +1,7 @@ +package errs + +import "errors" + +var ( + EmptyToken = errors.New("empty token") +) diff --git a/internal/errs/errors.go b/internal/errs/errors.go new file mode 100644 index 0000000000000000000000000000000000000000..6bac1a284802d0677ba9d2c37ffa651a80299355 --- /dev/null +++ b/internal/errs/errors.go @@ -0,0 +1,40 @@ +package errs + +import ( + "errors" + "fmt" + + pkgerr "github.com/pkg/errors" +) + +var ( + NotImplement = errors.New("not implement") + NotSupport = errors.New("not support") + RelativePath = errors.New("access using relative path is not allowed") + + MoveBetweenTwoStorages = errors.New("can't move files between two storages, try to copy") + UploadNotSupported = errors.New("upload not supported") + + MetaNotFound = errors.New("meta not found") + StorageNotFound = errors.New("storage not found") + StreamIncomplete = errors.New("upload/download stream incomplete, possible network issue") + StreamPeekFail = errors.New("StreamPeekFail") +) + +// NewErr wrap constant error with an extra message +// use errors.Is(err1, StorageNotFound) to check if err belongs to any internal error +func NewErr(err error, format string, a ...any) error { + return fmt.Errorf("%w; %s", err, fmt.Sprintf(format, a...)) +} + +func IsNotFoundError(err error) bool { + return errors.Is(pkgerr.Cause(err), ObjectNotFound) || errors.Is(pkgerr.Cause(err), StorageNotFound) +} + +func IsNotSupportError(err error) bool { + return errors.Is(pkgerr.Cause(err), NotSupport) +} + +func IsNotImplement(err error) bool { + return errors.Is(pkgerr.Cause(err), NotImplement) +} diff --git a/internal/errs/errors_test.go b/internal/errs/errors_test.go new file mode 100644 index 0000000000000000000000000000000000000000..3169a8772567083633484f9afc25673f4367f61e --- /dev/null +++ b/internal/errs/errors_test.go @@ -0,0 +1,27 @@ +package errs + +import ( + "errors" + pkgerr "github.com/pkg/errors" + "testing" +) + +func TestErrs(t *testing.T) { + + err1 := NewErr(StorageNotFound, "please add a storage first") + t.Logf("err1: %s", err1) + if 
!errors.Is(err1, StorageNotFound) { + t.Errorf("failed, expect %s is %s", err1, StorageNotFound) + } + if !errors.Is(pkgerr.Cause(err1), StorageNotFound) { + t.Errorf("failed, expect %s is %s", err1, StorageNotFound) + } + err2 := pkgerr.WithMessage(err1, "failed get storage") + t.Logf("err2: %s", err2) + if !errors.Is(err2, StorageNotFound) { + t.Errorf("failed, expect %s is %s", err2, StorageNotFound) + } + if !errors.Is(pkgerr.Cause(err2), StorageNotFound) { + t.Errorf("failed, expect %s is %s", err2, StorageNotFound) + } +} diff --git a/internal/errs/object.go b/internal/errs/object.go new file mode 100644 index 0000000000000000000000000000000000000000..00e8232ff95b5c7958ec585756487394c6c9ca46 --- /dev/null +++ b/internal/errs/object.go @@ -0,0 +1,17 @@ +package errs + +import ( + "errors" + + pkgerr "github.com/pkg/errors" +) + +var ( + ObjectNotFound = errors.New("object not found") + NotFolder = errors.New("not a folder") + NotFile = errors.New("not a file") +) + +func IsObjectNotFound(err error) bool { + return errors.Is(pkgerr.Cause(err), ObjectNotFound) +} diff --git a/internal/errs/operate.go b/internal/errs/operate.go new file mode 100644 index 0000000000000000000000000000000000000000..92fbd6a1a49dc208ea9d8ce4bb2b3cf4a0e49e59 --- /dev/null +++ b/internal/errs/operate.go @@ -0,0 +1,7 @@ +package errs + +import "errors" + +var ( + PermissionDenied = errors.New("permission denied") +) diff --git a/internal/errs/search.go b/internal/errs/search.go new file mode 100644 index 0000000000000000000000000000000000000000..9c864f4d2414f7bf5510bd5b9f8e3ca0c58faade --- /dev/null +++ b/internal/errs/search.go @@ -0,0 +1,7 @@ +package errs + +import "fmt" + +var ( + SearchNotAvailable = fmt.Errorf("search not available") +) diff --git a/internal/errs/user.go b/internal/errs/user.go new file mode 100644 index 0000000000000000000000000000000000000000..9e2d5b26cd4df57327e6013daa574a5d1a1fd765 --- /dev/null +++ b/internal/errs/user.go @@ -0,0 +1,10 @@ +package errs + +import "errors" + +var ( + EmptyUsername = errors.New("username is empty") + EmptyPassword = errors.New("password is empty") + WrongPassword = errors.New("password is incorrect") + DeleteAdminOrGuest = errors.New("cannot delete admin or guest") +) diff --git a/internal/fs/copy.go b/internal/fs/copy.go new file mode 100644 index 0000000000000000000000000000000000000000..5650bfd0f030db54b143318bc8f3f189c24ebda5 --- /dev/null +++ b/internal/fs/copy.go @@ -0,0 +1,234 @@ +package fs + +import ( + "context" + "fmt" + "net/http" + stdpath "path" + + "github.com/alist-org/alist/v3/internal/conf" + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/internal/op" + "github.com/alist-org/alist/v3/internal/setting" + "github.com/alist-org/alist/v3/internal/stream" + "github.com/alist-org/alist/v3/pkg/tache" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/pkg/errors" + log "github.com/sirupsen/logrus" +) + +type CopyTask struct { + tache.Base + Status string `json:"-"` //don't save status to save space + SrcObjPath string `json:"src_path"` + DstDirPath string `json:"dst_path"` + Override bool `json:"override"` + srcStorage driver.Driver `json:"-"` + dstStorage driver.Driver `json:"-"` + SrcStorageMp string `json:"src_storage_mp"` + DstStorageMp string `json:"dst_storage_mp"` + Size int64 `json:"size"` +} + +func (t *CopyTask) GetName() string { + return fmt.Sprintf("copy [%s](%s) to [%s](%s)", t.SrcStorageMp, t.SrcObjPath, t.DstStorageMp, t.DstDirPath) +} + 
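+// For a task copying /docs/a.txt from a storage mounted at /local into
+// /backup on a storage mounted at /remote (hypothetical mount points),
+// GetName above yields:
+//
+// copy [/local](/docs/a.txt) to [/remote](/backup)
+//
+// which is the label shown in task listings, keeping the mount path and
+// the in-storage path distinguishable.
+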
+func (t *CopyTask) GetStatus() string {
+ return t.Status
+}
+
+func (t *CopyTask) SetSize(size int64) {
+ t.Size = size
+}
+
+func (t *CopyTask) GetSize() int64 {
+ return t.Size
+}
+
+func (t *CopyTask) OnFailed() {
+ result := fmt.Sprintf("%s:%s", t.GetName(), t.GetErr())
+ log.Debug(result)
+ if setting.GetBool(conf.NotifyEnabled) && setting.GetBool(conf.NotifyOnCopyFailed) {
+ go op.Notify("File copy result", result)
+ }
+}
+
+func (t *CopyTask) OnSucceeded() {
+ result := fmt.Sprintf("copied %s to %s successfully", t.SrcObjPath, t.DstDirPath)
+ log.Debug(result)
+ if setting.GetBool(conf.NotifyEnabled) && setting.GetBool(conf.NotifyOnCopySucceeded) {
+ go op.Notify("File copy result", result)
+ }
+}
+
+func humanReadableSize(size int64) string {
+ const unit = 1024
+ if size < unit {
+ return fmt.Sprintf("%d B", size)
+ }
+ div, exp := int64(unit), 0
+ for n := size / unit; n >= unit; n /= unit {
+ div *= unit
+ exp++
+ }
+ return fmt.Sprintf("%.1f %cB", float64(size)/float64(div), "KMGTPE"[exp])
+}
+
+func (t *CopyTask) Run() error {
+ var err error
+ if t.srcStorage == nil {
+ if t.srcStorage, err = op.GetStorageByMountPath(t.SrcStorageMp); err != nil {
+ return errors.WithMessage(err, "failed get src storage")
+ }
+ }
+ if t.dstStorage == nil {
+ if t.dstStorage, err = op.GetStorageByMountPath(t.DstStorageMp); err != nil {
+ return errors.WithMessage(err, "failed get dst storage")
+ }
+ }
+
+ if !t.Override {
+ srcObj, err := get(context.Background(), t.SrcStorageMp+t.SrcObjPath)
+ if err != nil {
+ return errors.WithMessagef(err, "failed get src [%s] file", t.SrcObjPath)
+ }
+ if srcObj.IsDir() {
+ return copyBetween2Storages(t, t.srcStorage, t.dstStorage, t.SrcObjPath, t.DstDirPath)
+ }
+ var distSize int64
+ t.Size = srcObj.GetSize()
+ dstPath := stdpath.Join(t.DstStorageMp+t.DstDirPath, srcObj.GetName())
+ obj, err := get(context.Background(), dstPath)
+ if err == nil {
+ distSize = obj.GetSize()
+ }
+ if err != nil || distSize != t.Size {
+ // destination missing or size differs: copy it
+ return copyBetween2Storages(t, t.srcStorage, t.dstStorage, t.SrcObjPath, t.DstDirPath)
+ }
+ // destination already exists with the same size: nothing to do
+ return nil
+ }
+ return copyBetween2Storages(t, t.srcStorage, t.dstStorage, t.SrcObjPath, t.DstDirPath)
+}
+
+var CopyTaskManager *tache.Manager[*CopyTask]
+
+// Copy copies within one storage via the driver's own copy; across two
+// storages it enqueues a copy task instead
+func _copy(ctx context.Context, SrcObjPath, DstDirPath string, overwrite bool, lazyCache ...bool) (tache.TaskWithInfo, error) {
+ srcStorage, srcObjActualPath, err := op.GetStorageAndActualPath(SrcObjPath)
+ if err != nil {
+ return nil, errors.WithMessage(err, "failed get src storage")
+ }
+ dstStorage, dstDirActualPath, err := op.GetStorageAndActualPath(DstDirPath)
+ if err != nil {
+ return nil, errors.WithMessage(err, "failed get dst storage")
+ }
+
+ // copy if in the same storage, just call driver.Copy
+ if srcStorage.GetStorage() == dstStorage.GetStorage() {
+ return nil, op.Copy(ctx, srcStorage, srcObjActualPath, dstDirActualPath, lazyCache...) 
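+ // Same underlying storage: the driver copies server-side, nothing is
+ // streamed through alist, and the nil task tells the caller no
+ // background task was queued.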
+ } + if ctx.Value(conf.NoTaskKey) != nil { + srcObj, err := op.Get(ctx, srcStorage, srcObjActualPath) + if err != nil { + return nil, errors.WithMessagef(err, "failed get src [%s] file", SrcObjPath) + } + if !srcObj.IsDir() { + // copy file directly + link, _, err := op.Link(ctx, srcStorage, srcObjActualPath, model.LinkArgs{ + Header: http.Header{}, + }) + if err != nil { + return nil, errors.WithMessagef(err, "failed get [%s] link", SrcObjPath) + } + fs := stream.FileStream{ + Obj: srcObj, + Ctx: ctx, + } + // any link provided is seekable + ss, err := stream.NewSeekableStream(fs, link) + if err != nil { + return nil, errors.WithMessagef(err, "failed get [%s] stream", SrcObjPath) + } + return nil, op.Put(ctx, dstStorage, dstDirActualPath, ss, nil, false) + } + } + // not in the same storage + + t := &CopyTask{ + srcStorage: srcStorage, + dstStorage: dstStorage, + SrcObjPath: srcObjActualPath, + DstDirPath: dstDirActualPath, + Override: overwrite, + SrcStorageMp: srcStorage.GetStorage().MountPath, + DstStorageMp: dstStorage.GetStorage().MountPath, + } + CopyTaskManager.Add(t) + return t, nil +} + +func copyBetween2Storages(t *CopyTask, srcStorage, dstStorage driver.Driver, SrcObjPath, DstDirPath string) error { + t.Status = "getting src object" + srcObj, err := op.Get(t.Ctx(), srcStorage, SrcObjPath) + if err != nil { + return errors.WithMessagef(err, "failed get src [%s] file", SrcObjPath) + } + if srcObj.IsDir() { + t.Status = "src object is dir, listing objs" + objs, err := op.List(t.Ctx(), srcStorage, SrcObjPath, model.ListArgs{}) + if err != nil { + return errors.WithMessagef(err, "failed list src [%s] objs", SrcObjPath) + } + for _, obj := range objs { + if utils.IsCanceled(t.Ctx()) { + return nil + } + SrcObjPath := stdpath.Join(SrcObjPath, obj.GetName()) + dstObjPath := stdpath.Join(DstDirPath, srcObj.GetName()) + CopyTaskManager.Add(&CopyTask{ + srcStorage: srcStorage, + dstStorage: dstStorage, + SrcObjPath: SrcObjPath, + DstDirPath: dstObjPath, + Override: t.Override, + SrcStorageMp: srcStorage.GetStorage().MountPath, + DstStorageMp: dstStorage.GetStorage().MountPath, + }) + } + t.Status = "src object is dir, added all copy tasks of objs" + return nil + } + return copyFileBetween2Storages(t, srcStorage, dstStorage, SrcObjPath, DstDirPath) +} + +func copyFileBetween2Storages(tsk *CopyTask, srcStorage, dstStorage driver.Driver, srcFilePath, DstDirPath string) error { + tsk.Status = fmt.Sprintf("getting src object (%s)", humanReadableSize(tsk.Size)) + srcFile, err := op.Get(tsk.Ctx(), srcStorage, srcFilePath) + if err != nil { + return errors.WithMessagef(err, "failed get src [%s] file", srcFilePath) + } + link, _, err := op.Link(tsk.Ctx(), srcStorage, srcFilePath, model.LinkArgs{ + Header: http.Header{}, + }) + if err != nil { + return errors.WithMessagef(err, "failed get [%s] link", srcFilePath) + } + fs := stream.FileStream{ + Obj: srcFile, + Ctx: tsk.Ctx(), + } + // any link provided is seekable + ss, err := stream.NewSeekableStream(fs, link) + if err != nil { + return errors.WithMessagef(err, "failed get [%s] stream", srcFilePath) + } + return op.Put(tsk.Ctx(), dstStorage, DstDirPath, ss, tsk.SetProgress, true) +} diff --git a/internal/fs/fs.go b/internal/fs/fs.go new file mode 100644 index 0000000000000000000000000000000000000000..09d87dce4c27b662c76cf6fb280b5b5160a9f1a4 --- /dev/null +++ b/internal/fs/fs.go @@ -0,0 +1,130 @@ +package fs + +import ( + "context" + + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/model" + 
"github.com/alist-org/alist/v3/internal/op" + "github.com/alist-org/alist/v3/pkg/tache" + log "github.com/sirupsen/logrus" +) + +// the param named path of functions in this package is a mount path +// So, the purpose of this package is to convert mount path to actual path +// then pass the actual path to the op package + +type ListArgs struct { + Refresh bool + NoLog bool +} + +func List(ctx context.Context, path string, args *ListArgs) ([]model.Obj, error) { + res, err := list(ctx, path, args) + if err != nil { + if !args.NoLog { + log.Errorf("failed list %s: %+v", path, err) + } + return nil, err + } + return res, nil +} + +type GetArgs struct { + NoLog bool +} + +func Get(ctx context.Context, path string, args *GetArgs) (model.Obj, error) { + res, err := get(ctx, path) + if err != nil { + if !args.NoLog { + log.Warnf("failed get %s: %s", path, err) + } + return nil, err + } + return res, nil +} + +func Link(ctx context.Context, path string, args model.LinkArgs) (*model.Link, model.Obj, error) { + res, file, err := link(ctx, path, args) + if err != nil { + log.Errorf("failed link %s: %+v", path, err) + return nil, nil, err + } + return res, file, nil +} + +func MakeDir(ctx context.Context, path string, lazyCache ...bool) error { + err := makeDir(ctx, path, lazyCache...) + if err != nil { + log.Errorf("failed make dir %s: %+v", path, err) + } + return err +} + +func Move(ctx context.Context, srcPath, dstDirPath string, lazyCache ...bool) error { + err := move(ctx, srcPath, dstDirPath, lazyCache...) + if err != nil { + log.Errorf("failed move %s to %s: %+v", srcPath, dstDirPath, err) + } + return err +} + +func Copy(ctx context.Context, srcObjPath, dstDirPath string, overwrite bool, lazyCache ...bool) (tache.TaskWithInfo, error) { + res, err := _copy(ctx, srcObjPath, dstDirPath, overwrite, lazyCache...) + if err != nil { + log.Errorf("failed copy %s to %s: %+v", srcObjPath, dstDirPath, err) + } + return res, err +} + +func Rename(ctx context.Context, srcPath, dstName string, lazyCache ...bool) error { + err := rename(ctx, srcPath, dstName, lazyCache...) + if err != nil { + log.Errorf("failed rename %s to %s: %+v", srcPath, dstName, err) + } + return err +} + +func Remove(ctx context.Context, path string) error { + err := remove(ctx, path) + if err != nil { + log.Errorf("failed remove %s: %+v", path, err) + } + return err +} + +func PutDirectly(ctx context.Context, dstDirPath string, file model.FileStreamer, lazyCache ...bool) error { + err := putDirectly(ctx, dstDirPath, file, lazyCache...) 
+ if err != nil { + log.Errorf("failed put %s: %+v", dstDirPath, err) + } + return err +} + +func PutAsTask(dstDirPath string, file model.FileStreamer) (tache.TaskWithInfo, error) { + t, err := putAsTask(dstDirPath, file) + if err != nil { + log.Errorf("failed put %s: %+v", dstDirPath, err) + } + return t, err +} + +type GetStoragesArgs struct { +} + +func GetStorage(path string, args *GetStoragesArgs) (driver.Driver, error) { + storageDriver, _, err := op.GetStorageAndActualPath(path) + if err != nil { + return nil, err + } + return storageDriver, nil +} + +func Other(ctx context.Context, args model.FsOtherArgs) (interface{}, error) { + res, err := other(ctx, args) + if err != nil { + log.Errorf("failed remove %s: %+v", args.Path, err) + } + return res, err +} diff --git a/internal/fs/get.go b/internal/fs/get.go new file mode 100644 index 0000000000000000000000000000000000000000..17c202b7412002e785c3978a268f180b4a1b1a9d --- /dev/null +++ b/internal/fs/get.go @@ -0,0 +1,39 @@ +package fs + +import ( + "context" + stdpath "path" + "time" + + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/internal/op" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/pkg/errors" +) + +func get(ctx context.Context, path string) (model.Obj, error) { + path = utils.FixAndCleanPath(path) + // maybe a virtual file + if path != "/" { + virtualFiles := op.GetStorageVirtualFilesByPath(stdpath.Dir(path)) + for _, f := range virtualFiles { + if f.GetName() == stdpath.Base(path) { + return f, nil + } + } + } + storage, actualPath, err := op.GetStorageAndActualPath(path) + if err != nil { + // if there are no storage prefix with path, maybe root folder + if path == "/" { + return &model.Object{ + Name: "root", + Size: 0, + Modified: time.Time{}, + IsFolder: true, + }, nil + } + return nil, errors.WithMessage(err, "failed get storage") + } + return op.Get(ctx, storage, actualPath) +} diff --git a/internal/fs/link.go b/internal/fs/link.go new file mode 100644 index 0000000000000000000000000000000000000000..3dfd7e5a35eb3ce682c8d7abc10aac825a73f8de --- /dev/null +++ b/internal/fs/link.go @@ -0,0 +1,29 @@ +package fs + +import ( + "context" + "strings" + + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/internal/op" + "github.com/alist-org/alist/v3/server/common" + "github.com/gin-gonic/gin" + "github.com/pkg/errors" +) + +func link(ctx context.Context, path string, args model.LinkArgs) (*model.Link, model.Obj, error) { + storage, actualPath, err := op.GetStorageAndActualPath(path) + if err != nil { + return nil, nil, errors.WithMessage(err, "failed get storage") + } + l, obj, err := op.Link(ctx, storage, actualPath, args) + if err != nil { + return nil, nil, errors.WithMessage(err, "failed link") + } + if l.URL != "" && !strings.HasPrefix(l.URL, "http://") && !strings.HasPrefix(l.URL, "https://") { + if c, ok := ctx.(*gin.Context); ok { + l.URL = common.GetApiUrl(c.Request) + l.URL + } + } + return l, obj, nil +} diff --git a/internal/fs/list.go b/internal/fs/list.go new file mode 100644 index 0000000000000000000000000000000000000000..6e257cea6fa959bb35aa8e2963ead001044d16ce --- /dev/null +++ b/internal/fs/list.go @@ -0,0 +1,65 @@ +package fs + +import ( + "context" + + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/internal/op" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/pkg/errors" + log "github.com/sirupsen/logrus" +) + +// List files +func list(ctx context.Context, path string, args *ListArgs) 
([]model.Obj, error) { + meta, _ := ctx.Value("meta").(*model.Meta) + user, _ := ctx.Value("user").(*model.User) + virtualFiles := op.GetStorageVirtualFilesByPath(path) + storage, actualPath, err := op.GetStorageAndActualPath(path) + if err != nil && len(virtualFiles) == 0 { + return nil, errors.WithMessage(err, "failed get storage") + } + + var _objs []model.Obj + if storage != nil { + _objs, err = op.List(ctx, storage, actualPath, model.ListArgs{ + ReqPath: path, + }, args.Refresh) + if err != nil { + if !args.NoLog { + log.Errorf("fs/list: %+v", err) + } + if len(virtualFiles) == 0 { + return nil, errors.WithMessage(err, "failed get objs") + } + } + } + + om := model.NewObjMerge() + if whetherHide(user, meta, path) { + om.InitHideReg(meta.Hide) + } + objs := om.Merge(_objs, virtualFiles...) + return objs, nil +} + +func whetherHide(user *model.User, meta *model.Meta, path string) bool { + // if is admin, don't hide + if user == nil || user.CanSeeHides() { + return false + } + // if meta is nil, don't hide + if meta == nil { + return false + } + // if meta.Hide is empty, don't hide + if meta.Hide == "" { + return false + } + // if meta doesn't apply to sub_folder, don't hide + if !utils.PathEqual(meta.Path, path) && !meta.HSub { + return false + } + // if is guest, hide + return true +} diff --git a/internal/fs/other.go b/internal/fs/other.go new file mode 100644 index 0000000000000000000000000000000000000000..85b7b1d17bfbf74001671079528b02a61a356bf8 --- /dev/null +++ b/internal/fs/other.go @@ -0,0 +1,58 @@ +package fs + +import ( + "context" + + "github.com/alist-org/alist/v3/internal/errs" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/internal/op" + "github.com/pkg/errors" +) + +func makeDir(ctx context.Context, path string, lazyCache ...bool) error { + storage, actualPath, err := op.GetStorageAndActualPath(path) + if err != nil { + return errors.WithMessage(err, "failed get storage") + } + return op.MakeDir(ctx, storage, actualPath, lazyCache...) +} + +func move(ctx context.Context, srcPath, dstDirPath string, lazyCache ...bool) error { + srcStorage, srcActualPath, err := op.GetStorageAndActualPath(srcPath) + if err != nil { + return errors.WithMessage(err, "failed get src storage") + } + dstStorage, dstDirActualPath, err := op.GetStorageAndActualPath(dstDirPath) + if err != nil { + return errors.WithMessage(err, "failed get dst storage") + } + if srcStorage.GetStorage() != dstStorage.GetStorage() { + return errors.WithStack(errs.MoveBetweenTwoStorages) + } + return op.Move(ctx, srcStorage, srcActualPath, dstDirActualPath, lazyCache...) +} + +func rename(ctx context.Context, srcPath, dstName string, lazyCache ...bool) error { + storage, srcActualPath, err := op.GetStorageAndActualPath(srcPath) + if err != nil { + return errors.WithMessage(err, "failed get storage") + } + return op.Rename(ctx, storage, srcActualPath, dstName, lazyCache...) 
+} + +func remove(ctx context.Context, path string) error { + storage, actualPath, err := op.GetStorageAndActualPath(path) + if err != nil { + return errors.WithMessage(err, "failed get storage") + } + return op.Remove(ctx, storage, actualPath) +} + +func other(ctx context.Context, args model.FsOtherArgs) (interface{}, error) { + storage, actualPath, err := op.GetStorageAndActualPath(args.Path) + if err != nil { + return nil, errors.WithMessage(err, "failed get storage") + } + args.Path = actualPath + return op.Other(ctx, storage, args) +} diff --git a/internal/fs/put.go b/internal/fs/put.go new file mode 100644 index 0000000000000000000000000000000000000000..1ba7c1462991e12e2fd28c329340e38a6637aaf7 --- /dev/null +++ b/internal/fs/put.go @@ -0,0 +1,90 @@ +package fs + +import ( + "context" + "fmt" + + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/errs" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/internal/op" + "github.com/alist-org/alist/v3/pkg/tache" + "github.com/pkg/errors" +) + +type UploadTask struct { + tache.Base + Name string `json:"name"` + Status string `json:"status"` + storage driver.Driver + dstDirActualPath string + file model.FileStreamer +} + +// func (t *UploadTask) OnFailed() { +// result := fmt.Sprintf("%s上传失败:%s", t.file.GetName(), t.GetErr()) +// log.Debug(result) +// go op.Notify("文件上传结果", result) +// } + +// func (t *UploadTask) OnSucceeded() { +// result := fmt.Sprintf("%s上传成功", t.file.GetName()) +// log.Debug(result) +// go op.Notify("文件上传结果", "文件复制成功") +// } + +func (t *UploadTask) GetName() string { + return t.Name + //return fmt.Sprintf("upload %s to [%s](%s)", t.file.GetName(), t.storage.GetStorage().MountPath, t.dstDirActualPath) +} + +func (t *UploadTask) GetStatus() string { + return t.Status + //return "uploading" +} + +func (t *UploadTask) Run() error { + return op.Put(t.Ctx(), t.storage, t.dstDirActualPath, t.file, t.SetProgress, true) +} + +var UploadTaskManager *tache.Manager[*UploadTask] + +// putAsTask add as a put task and return immediately +func putAsTask(dstDirPath string, file model.FileStreamer) (tache.TaskWithInfo, error) { + storage, dstDirActualPath, err := op.GetStorageAndActualPath(dstDirPath) + if err != nil { + return nil, errors.WithMessage(err, "failed get storage") + } + if storage.Config().NoUpload { + return nil, errors.WithStack(errs.UploadNotSupported) + } + if file.NeedStore() { + _, err := file.CacheFullInTempFile() + if err != nil { + return nil, errors.Wrapf(err, "failed to create temp file") + } + //file.SetReader(tempFile) + //file.SetTmpFile(tempFile) + } + t := &UploadTask{ + Name: fmt.Sprintf("upload %s to [%s](%s)", file.GetName(), storage.GetStorage().MountPath, dstDirActualPath), + Status: "uploading", + storage: storage, + dstDirActualPath: dstDirActualPath, + file: file, + } + UploadTaskManager.Add(t) + return t, nil +} + +// putDirect put the file and return after finish +func putDirectly(ctx context.Context, dstDirPath string, file model.FileStreamer, lazyCache ...bool) error { + storage, dstDirActualPath, err := op.GetStorageAndActualPath(dstDirPath) + if err != nil { + return errors.WithMessage(err, "failed get storage") + } + if storage.Config().NoUpload { + return errors.WithStack(errs.UploadNotSupported) + } + return op.Put(ctx, storage, dstDirActualPath, file, nil, lazyCache...) 
+} diff --git a/internal/fs/walk.go b/internal/fs/walk.go new file mode 100644 index 0000000000000000000000000000000000000000..9e1eef76d788d87260855fbaaf65ed92a7c4453f --- /dev/null +++ b/internal/fs/walk.go @@ -0,0 +1,45 @@ +package fs + +import ( + "context" + "path" + "path/filepath" + + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/internal/op" +) + +// WalkFS traverses filesystem fs starting at name up to depth levels. +// +// WalkFS will stop when current depth > `depth`. For each visited node, +// WalkFS calls walkFn. If a visited file system node is a directory and +// walkFn returns path.SkipDir, walkFS will skip traversal of this node. +func WalkFS(ctx context.Context, depth int, name string, info model.Obj, walkFn func(reqPath string, info model.Obj) error) error { + // This implementation is based on Walk's code in the standard path/path package. + walkFnErr := walkFn(name, info) + if walkFnErr != nil { + if info.IsDir() && walkFnErr == filepath.SkipDir { + return nil + } + return walkFnErr + } + if !info.IsDir() || depth == 0 { + return nil + } + meta, _ := op.GetNearestMeta(name) + // Read directory names. + objs, err := List(context.WithValue(ctx, "meta", meta), name, &ListArgs{}) + if err != nil { + return walkFnErr + } + for _, fileInfo := range objs { + filename := path.Join(name, fileInfo.GetName()) + if err := WalkFS(ctx, depth-1, filename, fileInfo, walkFn); err != nil { + if err == filepath.SkipDir { + break + } + return err + } + } + return nil +} diff --git a/internal/fuse/fs.go b/internal/fuse/fs.go new file mode 100644 index 0000000000000000000000000000000000000000..7783b169fb0f17a837334c6f0099688bd31fcd10 --- /dev/null +++ b/internal/fuse/fs.go @@ -0,0 +1,170 @@ +package fuse + +import "github.com/winfsp/cgofuse/fuse" + +type Fs struct { + RootFolder string + fuse.FileSystemBase +} + +func (fs *Fs) Init() { + //TODO implement me + panic("implement me") +} + +func (fs *Fs) Destroy() { + //TODO implement me + panic("implement me") +} + +func (fs *Fs) Statfs(path string, stat *fuse.Statfs_t) int { + //TODO implement me + panic("implement me") +} + +func (fs *Fs) Mknod(path string, mode uint32, dev uint64) int { + //TODO implement me + panic("implement me") +} + +func (fs *Fs) Mkdir(path string, mode uint32) int { + //TODO implement me + panic("implement me") +} + +func (fs *Fs) Unlink(path string) int { + //TODO implement me + panic("implement me") +} + +func (fs *Fs) Rmdir(path string) int { + //TODO implement me + panic("implement me") +} + +func (fs *Fs) Link(oldpath string, newpath string) int { + //TODO implement me + panic("implement me") +} + +func (fs *Fs) Symlink(target string, newpath string) int { + //TODO implement me + panic("implement me") +} + +func (fs *Fs) Readlink(path string) (int, string) { + //TODO implement me + panic("implement me") +} + +func (fs *Fs) Rename(oldpath string, newpath string) int { + //TODO implement me + panic("implement me") +} + +func (fs *Fs) Chmod(path string, mode uint32) int { + //TODO implement me + panic("implement me") +} + +func (fs *Fs) Chown(path string, uid uint32, gid uint32) int { + //TODO implement me + panic("implement me") +} + +func (fs *Fs) Utimens(path string, tmsp []fuse.Timespec) int { + //TODO implement me + panic("implement me") +} + +func (fs *Fs) Access(path string, mask uint32) int { + //TODO implement me + panic("implement me") +} + +func (fs *Fs) Create(path string, flags int, mode uint32) (int, uint64) { + //TODO implement me + panic("implement me") +} + +func (fs 
*Fs) Open(path string, flags int) (int, uint64) { + //TODO implement me + panic("implement me") +} + +func (fs *Fs) Getattr(path string, stat *fuse.Stat_t, fh uint64) int { + //TODO implement me + panic("implement me") +} + +func (fs *Fs) Truncate(path string, size int64, fh uint64) int { + //TODO implement me + panic("implement me") +} + +func (fs *Fs) Read(path string, buff []byte, ofst int64, fh uint64) int { + //TODO implement me + panic("implement me") +} + +func (fs *Fs) Write(path string, buff []byte, ofst int64, fh uint64) int { + //TODO implement me + panic("implement me") +} + +func (fs *Fs) Flush(path string, fh uint64) int { + //TODO implement me + panic("implement me") +} + +func (fs *Fs) Release(path string, fh uint64) int { + //TODO implement me + panic("implement me") +} + +func (fs *Fs) Fsync(path string, datasync bool, fh uint64) int { + //TODO implement me + panic("implement me") +} + +func (fs *Fs) Opendir(path string) (int, uint64) { + //TODO implement me + panic("implement me") +} + +func (fs *Fs) Readdir(path string, fill func(name string, stat *fuse.Stat_t, ofst int64) bool, ofst int64, fh uint64) int { + //TODO implement me + panic("implement me") +} + +func (fs *Fs) Releasedir(path string, fh uint64) int { + //TODO implement me + panic("implement me") +} + +func (fs *Fs) Fsyncdir(path string, datasync bool, fh uint64) int { + //TODO implement me + panic("implement me") +} + +func (fs *Fs) Setxattr(path string, name string, value []byte, flags int) int { + //TODO implement me + panic("implement me") +} + +func (fs *Fs) Getxattr(path string, name string) (int, []byte) { + //TODO implement me + panic("implement me") +} + +func (fs *Fs) Removexattr(path string, name string) int { + //TODO implement me + panic("implement me") +} + +func (fs *Fs) Listxattr(path string, fill func(name string) bool) int { + //TODO implement me + panic("implement me") +} + +var _ fuse.FileSystemInterface = (*Fs)(nil) diff --git a/internal/fuse/mount.go b/internal/fuse/mount.go new file mode 100644 index 0000000000000000000000000000000000000000..30e3d54647a0ba02f713a112e3cba031715a3451 --- /dev/null +++ b/internal/fuse/mount.go @@ -0,0 +1,9 @@ +package fuse + +import "github.com/winfsp/cgofuse/fuse" + +func Mount(mountSrc, mountDst string, opts []string) { + fs := &Fs{RootFolder: mountSrc} + host := fuse.NewFileSystemHost(fs) + go host.Mount(mountDst, opts) +} diff --git a/internal/message/http.go b/internal/message/http.go new file mode 100644 index 0000000000000000000000000000000000000000..3f023dec5e1de2ed1ceed5314a110ead8ced1d40 --- /dev/null +++ b/internal/message/http.go @@ -0,0 +1,82 @@ +package message + +import ( + "time" + + "github.com/alist-org/alist/v3/server/common" + "github.com/gin-gonic/gin" + "github.com/pkg/errors" +) + +type Http struct { + Received chan string // received messages from web + ToSend chan Message // messages to send to web +} + +type Req struct { + Message string `json:"message" form:"message"` +} + +func (p *Http) GetHandle(c *gin.Context) { + select { + case message := <-p.ToSend: + common.SuccessResp(c, message) + default: + common.ErrorStrResp(c, "no message", 404) + } +} + +func (p *Http) SendHandle(c *gin.Context) { + var req Req + if err := c.ShouldBind(&req); err != nil { + common.ErrorResp(c, err, 400) + return + } + select { + case p.Received <- req.Message: + common.SuccessResp(c) + default: + common.ErrorStrResp(c, "nowhere needed", 500) + } +} + +func (p *Http) Send(message Message) error { + select { + case p.ToSend <- message: + return nil + 
default: + return errors.New("send failed") + } +} + +func (p *Http) Receive() (string, error) { + select { + case message := <-p.Received: + return message, nil + default: + return "", errors.New("receive failed") + } +} + +func (p *Http) WaitSend(message Message, d int) error { + select { + case p.ToSend <- message: + return nil + case <-time.After(time.Duration(d) * time.Second): + return errors.New("send timeout") + } +} + +func (p *Http) WaitReceive(d int) (string, error) { + select { + case message := <-p.Received: + return message, nil + case <-time.After(time.Duration(d) * time.Second): + return "", errors.New("receive timeout") + } +} + +var HttpInstance = &Http{ + Received: make(chan string), + ToSend: make(chan Message), +} diff --git a/internal/message/message.go b/internal/message/message.go new file mode 100644 index 0000000000000000000000000000000000000000..0ca0f2ae5be685000c926a46bc9bb5b90062ccaa --- /dev/null +++ b/internal/message/message.go @@ -0,0 +1,17 @@ +package message + +type Message struct { + Type string `json:"type"` + Content interface{} `json:"content"` +} + +type Messenger interface { + Send(Message) error + Receive() (string, error) + WaitSend(Message, int) error + WaitReceive(int) (string, error) +} + +func GetMessenger() Messenger { + return HttpInstance +} diff --git a/internal/message/ws.go b/internal/message/ws.go new file mode 100644 index 0000000000000000000000000000000000000000..725b71d3297981746e9dcc4b0b6f1978ac3a2542 --- /dev/null +++ b/internal/message/ws.go @@ -0,0 +1,3 @@ +package message + +// TODO websocket implementation diff --git a/internal/model/args.go b/internal/model/args.go new file mode 100644 index 0000000000000000000000000000000000000000..613699b95b4ba121e02e4522a540675af18607e9 --- /dev/null +++ b/internal/model/args.go @@ -0,0 +1,70 @@ +package model + +import ( + "context" + "io" + "net/http" + "time" + + "github.com/alist-org/alist/v3/pkg/http_range" + "github.com/alist-org/alist/v3/pkg/utils" +) + +type ListArgs struct { + ReqPath string + S3ShowPlaceholder bool + Refresh bool +} + +type LinkArgs struct { + IP string + Header http.Header + Type string + HttpReq *http.Request +} + +type Link struct { + URL string `json:"url"` // most common way + Header http.Header `json:"header"` // needed header (for url) + RangeReadCloser RangeReadCloserIF `json:"-"` // recommended way if can't use URL + MFile File `json:"-"` // best for local,smb... 
file system, which exposes MFile + + Expiration *time.Duration // local cache expire Duration + IPCacheKey bool `json:"-"` // add ip to cache key + + //for accelerating request, use multi-thread downloading + Concurrency int `json:"concurrency"` + PartSize int `json:"part_size"` +} + +type OtherArgs struct { + Obj Obj + Method string + Data interface{} +} + +type FsOtherArgs struct { + Path string `json:"path" form:"path"` + Method string `json:"method" form:"method"` + Data interface{} `json:"data" form:"data"` +} +type RangeReadCloserIF interface { + RangeRead(ctx context.Context, httpRange http_range.Range) (io.ReadCloser, error) + utils.ClosersIF +} + +var _ RangeReadCloserIF = (*RangeReadCloser)(nil) + +type RangeReadCloser struct { + RangeReader RangeReaderFunc + utils.Closers +} + +func (r RangeReadCloser) RangeRead(ctx context.Context, httpRange http_range.Range) (io.ReadCloser, error) { + rc, err := r.RangeReader(ctx, httpRange) + r.Closers.Add(rc) + return rc, err +} + +// type WriterFunc func(w io.Writer) error +type RangeReaderFunc func(ctx context.Context, httpRange http_range.Range) (io.ReadCloser, error) diff --git a/internal/model/file.go b/internal/model/file.go new file mode 100644 index 0000000000000000000000000000000000000000..ba65ef938dbfff4f6056e40297c9c501fb6131ca --- /dev/null +++ b/internal/model/file.go @@ -0,0 +1,25 @@ +package model + +import "io" + +// File is basic file level accessing interface +type File interface { + io.Reader + io.ReaderAt + io.Seeker + io.Closer +} + +type NopMFileIF interface { + io.Reader + io.ReaderAt + io.Seeker +} +type NopMFile struct { + NopMFileIF +} + +func (NopMFile) Close() error { return nil } +func NewNopMFile(r NopMFileIF) File { + return NopMFile{r} +} diff --git a/internal/model/meta.go b/internal/model/meta.go new file mode 100644 index 0000000000000000000000000000000000000000..0446137a2c05a8cea8024f472554aedfef5b98c8 --- /dev/null +++ b/internal/model/meta.go @@ -0,0 +1,16 @@ +package model + +type Meta struct { + ID uint `json:"id" gorm:"primaryKey"` + Path string `json:"path" gorm:"unique" binding:"required"` + Password string `json:"password"` + PSub bool `json:"p_sub"` + Write bool `json:"write"` + WSub bool `json:"w_sub"` + Hide string `json:"hide"` + HSub bool `json:"h_sub"` + Readme string `json:"readme"` + RSub bool `json:"r_sub"` + Header string `json:"header"` + HeaderSub bool `json:"header_sub"` +} diff --git a/internal/model/notify.go b/internal/model/notify.go new file mode 100644 index 0000000000000000000000000000000000000000..33cd4b3d026a6bca74a247ee57043cf8446b1dae --- /dev/null +++ b/internal/model/notify.go @@ -0,0 +1,18 @@ +package model + +type Bark struct { + BarkPush string `json:"barkPush"` + BarkIcon string `json:"barkIcon,omitempty"` // 可选字段 + BarkSound string `json:"barkSound,omitempty"` // 可选字段 + BarkGroup string `json:"barkGroup,omitempty"` // 可选字段 + BarkLevel string `json:"barkLevel,omitempty"` // 可选字段 + BarkUrl string `json:"barkUrl,omitempty"` // 可选字段 +} + +type Webhook struct { + WebhookUrl string `json:"webhookUrl"` + WebhookBody string `json:"webhookBody,omitempty"` // 可选字段 + WebhookHeaders string `json:"webhookHeaders,omitempty"` // 可选字段 + WebhookMethod string `json:"webhookMethod"` // 可选字段 + WebhookContentType string `json:"webhookContentType"` // 可选字段 +} diff --git a/internal/model/obj.go b/internal/model/obj.go new file mode 100644 index 0000000000000000000000000000000000000000..122fb546278782ec099a715a58033a817b25028d --- /dev/null +++ b/internal/model/obj.go @@ -0,0 +1,211 @@ 
+package model + +import ( + "io" + "sort" + "strings" + "time" + + "github.com/alist-org/alist/v3/pkg/http_range" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/dlclark/regexp2" + + mapset "github.com/deckarep/golang-set/v2" + + "github.com/maruel/natural" +) + +type ObjUnwrap interface { + Unwrap() Obj +} + +type Obj interface { + GetSize() int64 + GetName() string + ModTime() time.Time + CreateTime() time.Time + IsDir() bool + GetHash() utils.HashInfo + + // The internal information of the driver. + // If you want to use it, please understand what it means + GetID() string + GetPath() string +} + +// FileStreamer ->check FileStream for more comments +type FileStreamer interface { + io.Reader + io.Closer + Obj + GetMimetype() string + //SetReader(io.Reader) + NeedStore() bool + IsForceStreamUpload() bool + GetExist() Obj + SetExist(Obj) + //for a non-seekable Stream, RangeRead supports peeking some data, and CacheFullInTempFile still works + RangeRead(http_range.Range) (io.Reader, error) + //for a non-seekable Stream, if Read is called, this function won't work + CacheFullInTempFile() (File, error) +} + +type URL interface { + URL() string +} + +type Thumb interface { + Thumb() string +} + +type SetPath interface { + SetPath(path string) +} + +func SortFiles(objs []Obj, orderBy, orderDirection string) { + if orderBy == "" { + return + } + sort.Slice(objs, func(i, j int) bool { + switch orderBy { + case "name": + { + c := natural.Less(objs[i].GetName(), objs[j].GetName()) + if orderDirection == "desc" { + return !c + } + return c + } + case "size": + { + if orderDirection == "desc" { + return objs[i].GetSize() >= objs[j].GetSize() + } + return objs[i].GetSize() <= objs[j].GetSize() + } + case "modified": + if orderDirection == "desc" { + return objs[i].ModTime().After(objs[j].ModTime()) + } + return objs[i].ModTime().Before(objs[j].ModTime()) + } + return false + }) +} + +func ExtractFolder(objs []Obj, extractFolder string) { + if extractFolder == "" { + return + } + front := extractFolder == "front" + sort.SliceStable(objs, func(i, j int) bool { + if objs[i].IsDir() || objs[j].IsDir() { + if !objs[i].IsDir() { + return !front + } + if !objs[j].IsDir() { + return front + } + } + return false + }) +} + +func WrapObjName(objs Obj) Obj { + return &ObjWrapName{Obj: objs} +} + +func WrapObjsName(objs []Obj) { + for i := 0; i < len(objs); i++ { + objs[i] = &ObjWrapName{Obj: objs[i]} + } +} + +func UnwrapObj(obj Obj) Obj { + if unwrap, ok := obj.(ObjUnwrap); ok { + obj = unwrap.Unwrap() + } + return obj +} + +func GetThumb(obj Obj) (thumb string, ok bool) { + if obj, ok := obj.(Thumb); ok { + return obj.Thumb(), true + } + if unwrap, ok := obj.(ObjUnwrap); ok { + return GetThumb(unwrap.Unwrap()) + } + return thumb, false +} + +func GetUrl(obj Obj) (url string, ok bool) { + if obj, ok := obj.(URL); ok { + return obj.URL(), true + } + if unwrap, ok := obj.(ObjUnwrap); ok { + return GetUrl(unwrap.Unwrap()) + } + return url, false +} + +func GetRawObject(obj Obj) *Object { + switch v := obj.(type) { + case *ObjThumbURL: + return &v.Object + case *ObjThumb: + return &v.Object + case *ObjectURL: + return &v.Object + case *Object: + return v + } + return nil +} + +// Merge +func NewObjMerge() *ObjMerge { + return &ObjMerge{ + set: mapset.NewSet[string](), + } +} + +type ObjMerge struct { + regs []*regexp2.Regexp + set mapset.Set[string] +} + +func (om *ObjMerge) Merge(objs []Obj, objs_ ...Obj) []Obj { + newObjs := make([]Obj, 0, len(objs)+len(objs_)) + newObjs = 
om.insertObjs(om.insertObjs(newObjs, objs...), objs_...) + return newObjs +} + +func (om *ObjMerge) insertObjs(objs []Obj, objs_ ...Obj) []Obj { + for _, obj := range objs_ { + if om.clickObj(obj) { + objs = append(objs, obj) + } + } + return objs +} + +func (om *ObjMerge) clickObj(obj Obj) bool { + for _, reg := range om.regs { + if isMatch, _ := reg.MatchString(obj.GetName()); isMatch { + return false + } + } + return om.set.Add(obj.GetName()) +} + +func (om *ObjMerge) InitHideReg(hides string) { + rs := strings.Split(hides, "\n") + om.regs = make([]*regexp2.Regexp, 0, len(rs)) + for _, r := range rs { + om.regs = append(om.regs, regexp2.MustCompile(r, regexp2.None)) + } +} + +func (om *ObjMerge) Reset() { + om.set.Clear() +} diff --git a/internal/model/object.go b/internal/model/object.go new file mode 100644 index 0000000000000000000000000000000000000000..93f2c307a03c141386a5fea4b7b3224bf69d8ab7 --- /dev/null +++ b/internal/model/object.go @@ -0,0 +1,104 @@ +package model + +import ( + "time" + + "github.com/alist-org/alist/v3/pkg/utils" +) + +type ObjWrapName struct { + Name string + Obj +} + +func (o *ObjWrapName) Unwrap() Obj { + return o.Obj +} + +func (o *ObjWrapName) GetName() string { + if o.Name == "" { + o.Name = utils.MappingName(o.Obj.GetName()) + } + return o.Name +} + +type Object struct { + ID string + Path string + Name string + Size int64 + Modified time.Time + Ctime time.Time // file create time + IsFolder bool + HashInfo utils.HashInfo +} + +func (o *Object) GetName() string { + return o.Name +} + +func (o *Object) GetSize() int64 { + return o.Size +} + +func (o *Object) ModTime() time.Time { + return o.Modified +} +func (o *Object) CreateTime() time.Time { + if o.Ctime.IsZero() { + return o.ModTime() + } + return o.Ctime +} + +func (o *Object) IsDir() bool { + return o.IsFolder +} + +func (o *Object) GetID() string { + return o.ID +} + +func (o *Object) GetPath() string { + return o.Path +} + +func (o *Object) SetPath(path string) { + o.Path = path +} + +func (o *Object) GetHash() utils.HashInfo { + return o.HashInfo +} + +type Thumbnail struct { + Thumbnail string +} + +type Url struct { + Url string +} + +func (w Url) URL() string { + return w.Url +} + +func (t Thumbnail) Thumb() string { + return t.Thumbnail +} + +type ObjThumb struct { + Object + Thumbnail +} + +type ObjectURL struct { + Object + Url +} + +type ObjThumbURL struct { + Object + Thumbnail + Url +} diff --git a/internal/model/req.go b/internal/model/req.go new file mode 100644 index 0000000000000000000000000000000000000000..fe3a08bd4cb5d88fbc9bb975a462e5ab3df0d3b3 --- /dev/null +++ b/internal/model/req.go @@ -0,0 +1,20 @@ +package model + +type PageReq struct { + Page int `json:"page" form:"page"` + PerPage int `json:"per_page" form:"per_page"` +} + +const MaxUint = ^uint(0) +const MinUint = 0 +const MaxInt = int(MaxUint >> 1) +const MinInt = -MaxInt - 1 + +func (p *PageReq) Validate() { + if p.Page < 1 { + p.Page = 1 + } + if p.PerPage < 1 { + p.PerPage = MaxInt + } +} diff --git a/internal/model/search.go b/internal/model/search.go new file mode 100644 index 0000000000000000000000000000000000000000..1e5c53ee92d7975efb84e15b240545a817736220 --- /dev/null +++ b/internal/model/search.go @@ -0,0 +1,42 @@ +package model + +import ( + "fmt" + "time" +) + +type IndexProgress struct { + ObjCount uint64 `json:"obj_count"` + IsDone bool `json:"is_done"` + LastDoneTime *time.Time `json:"last_done_time"` + Error string `json:"error"` +} + +type SearchReq struct { + Parent string `json:"parent"` + Keywords 
string `json:"keywords"` + // 0 for all, 1 for dir, 2 for file + Scope int `json:"scope"` + PageReq +} + +type SearchNode struct { + Parent string `json:"parent" gorm:"index"` + Name string `json:"name"` + IsDir bool `json:"is_dir"` + Size int64 `json:"size"` +} + +func (p *SearchReq) Validate() error { + if p.Page < 1 { + return fmt.Errorf("page can't < 1") + } + if p.PerPage < 1 { + return fmt.Errorf("per_page can't < 1") + } + return nil +} + +func (s *SearchNode) Type() string { + return "SearchNode" +} diff --git a/internal/model/setting.go b/internal/model/setting.go new file mode 100644 index 0000000000000000000000000000000000000000..34297820dce85bce3c0cce6bd7d69c92c3095764 --- /dev/null +++ b/internal/model/setting.go @@ -0,0 +1,37 @@ +package model + +const ( + SINGLE = iota + SITE + STYLE + PREVIEW + GLOBAL + OFFLINE_DOWNLOAD + INDEX + SSO + LDAP + S3 + NOTIFICATION +) + +const ( + PUBLIC = iota + PRIVATE + READONLY + DEPRECATED +) + +type SettingItem struct { + Key string `json:"key" gorm:"primaryKey" binding:"required"` // unique key + Value string `json:"value"` // value + PreDefault string `json:"-" gorm:"-:all"` // deprecated value + Help string `json:"help"` // help message + Type string `json:"type"` // string, number, bool, select + Options string `json:"options"` // values for select + Group int `json:"group"` // use to group setting in frontend + Flag int `json:"flag"` // 0 = public, 1 = private, 2 = readonly, 3 = deprecated, etc. +} + +func (s SettingItem) IsDeprecated() bool { + return s.Flag == DEPRECATED +} diff --git a/internal/model/storage.go b/internal/model/storage.go new file mode 100644 index 0000000000000000000000000000000000000000..de4b19a6533baa49a0efda4d616863bbe72b7859 --- /dev/null +++ b/internal/model/storage.go @@ -0,0 +1,58 @@ +package model + +import "time" + +type Storage struct { + ID uint `json:"id" gorm:"primaryKey"` // unique key + MountPath string `json:"mount_path" gorm:"unique" binding:"required"` // must be standardized + Order int `json:"order"` // use to sort + Driver string `json:"driver"` // driver used + CacheExpiration int `json:"cache_expiration"` // cache expire time + Status string `json:"status"` + Addition string `json:"addition" gorm:"type:text"` // Additional information, defined in the corresponding driver + Group string `json:"group"` + SyncGroup bool `json:"sync_group"` // 同步同组所有存储 + Remark string `json:"remark"` + Modified time.Time `json:"modified"` + Disabled bool `json:"disabled"` // if disabled + EnableSign bool `json:"enable_sign"` + Sort + Proxy +} + +type Sort struct { + OrderBy string `json:"order_by"` + OrderDirection string `json:"order_direction"` + ExtractFolder string `json:"extract_folder"` +} + +type Proxy struct { + WebProxy bool `json:"web_proxy"` + WebdavPolicy string `json:"webdav_policy"` + ProxyRange bool `json:"proxy_range"` + DownProxyUrl string `json:"down_proxy_url"` +} + +func (s *Storage) GetStorage() *Storage { + return s +} + +func (s *Storage) SetStorage(storage Storage) { + *s = storage +} + +func (s *Storage) SetStatus(status string) { + s.Status = status +} + +func (p Proxy) Webdav302() bool { + return p.WebdavPolicy == "302_redirect" +} + +func (p Proxy) WebdavProxy() bool { + return p.WebdavPolicy == "use_proxy_url" +} + +func (p Proxy) WebdavNative() bool { + return !p.Webdav302() && !p.WebdavProxy() +} diff --git a/internal/model/task.go b/internal/model/task.go new file mode 100644 index 0000000000000000000000000000000000000000..8a87c5a5062e7cbf5fa202fa3f04ccc96470d306 --- /dev/null +++ 
b/internal/model/task.go @@ -0,0 +1,6 @@ +package model + +type TaskItem struct { + Key string `json:"key"` + PersistData string `gorm:"type:text" json:"persist_data"` +} diff --git a/internal/model/user.go b/internal/model/user.go new file mode 100644 index 0000000000000000000000000000000000000000..2d61a971c3d8d9de400a891ec6136bb99d140c02 --- /dev/null +++ b/internal/model/user.go @@ -0,0 +1,161 @@ +package model + +import ( + "encoding/binary" + "encoding/json" + "fmt" + "time" + + "github.com/alist-org/alist/v3/internal/errs" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/alist-org/alist/v3/pkg/utils/random" + "github.com/go-webauthn/webauthn/webauthn" + "github.com/pkg/errors" +) + +const ( + GENERAL = iota + GUEST // only one exists + ADMIN +) + +const StaticHashSalt = "https://github.com/alist-org/alist" + +type User struct { + ID uint `json:"id" gorm:"primaryKey"` // unique key + Username string `json:"username" gorm:"unique" binding:"required"` // username + PwdHash string `json:"-"` // password hash + PwdTS int64 `json:"-"` // password timestamp + Salt string `json:"-"` // unique salt + Password string `json:"password"` // password + BasePath string `json:"base_path"` // base path + Role int `json:"role"` // user's role + Disabled bool `json:"disabled"` + // Determine permissions by bit + // 0: can see hidden files + // 1: can access without password + // 2: can add offline download tasks + // 3: can mkdir and upload + // 4: can rename + // 5: can move + // 6: can copy + // 7: can remove + // 8: webdav read + // 9: webdav write + Permission int32 `json:"permission"` + OtpSecret string `json:"-"` + SsoID string `json:"sso_id"` // unique by sso platform + Authn string `gorm:"type:text" json:"-"` +} + +func (u *User) IsGuest() bool { + return u.Role == GUEST +} + +func (u *User) IsAdmin() bool { + return u.Role == ADMIN +} + +func (u *User) ValidateRawPassword(password string) error { + return u.ValidatePwdStaticHash(StaticHash(password)) +} + +func (u *User) ValidatePwdStaticHash(pwdStaticHash string) error { + if pwdStaticHash == "" { + return errors.WithStack(errs.EmptyPassword) + } + if u.PwdHash != HashPwd(pwdStaticHash, u.Salt) { + return errors.WithStack(errs.WrongPassword) + } + return nil +} + +func (u *User) SetPassword(pwd string) *User { + u.Salt = random.String(16) + u.PwdHash = TwoHashPwd(pwd, u.Salt) + u.PwdTS = time.Now().Unix() + return u +} + +func (u *User) CanSeeHides() bool { + return u.IsAdmin() || u.Permission&1 == 1 +} + +func (u *User) CanAccessWithoutPassword() bool { + return u.IsAdmin() || (u.Permission>>1)&1 == 1 +} + +func (u *User) CanAddOfflineDownloadTasks() bool { + return u.IsAdmin() || (u.Permission>>2)&1 == 1 +} + +func (u *User) CanWrite() bool { + return u.IsAdmin() || (u.Permission>>3)&1 == 1 +} + +func (u *User) CanRename() bool { + return u.IsAdmin() || (u.Permission>>4)&1 == 1 +} + +func (u *User) CanMove() bool { + return u.IsAdmin() || (u.Permission>>5)&1 == 1 +} + +func (u *User) CanCopy() bool { + return u.IsAdmin() || (u.Permission>>6)&1 == 1 +} + +func (u *User) CanRemove() bool { + return u.IsAdmin() || (u.Permission>>7)&1 == 1 +} + +func (u *User) CanWebdavRead() bool { + return u.IsAdmin() || (u.Permission>>8)&1 == 1 +} + +func (u *User) CanWebdavManage() bool { + return u.IsAdmin() || (u.Permission>>9)&1 == 1 +} + +func (u *User) JoinPath(reqPath string) (string, error) { + return utils.JoinBasePath(u.BasePath, reqPath) +} + +func StaticHash(password string) string { + return utils.HashData(utils.SHA256, 
[]byte(fmt.Sprintf("%s-%s", password, StaticHashSalt))) +} + +func HashPwd(static string, salt string) string { + return utils.HashData(utils.SHA256, []byte(fmt.Sprintf("%s-%s", static, salt))) +} + +func TwoHashPwd(password string, salt string) string { + return HashPwd(StaticHash(password), salt) +} + +func (u *User) WebAuthnID() []byte { + bs := make([]byte, 8) + binary.LittleEndian.PutUint64(bs, uint64(u.ID)) + return bs +} + +func (u *User) WebAuthnName() string { + return u.Username +} + +func (u *User) WebAuthnDisplayName() string { + return u.Username +} + +func (u *User) WebAuthnCredentials() []webauthn.Credential { + var res []webauthn.Credential + err := json.Unmarshal([]byte(u.Authn), &res) + if err != nil { + fmt.Println(err) + } + return res +} + +func (u *User) WebAuthnIcon() string { + return "https://alist.nn.ci/logo.svg" +} diff --git a/internal/net/request.go b/internal/net/request.go new file mode 100644 index 0000000000000000000000000000000000000000..71f45aa7afc709b7d1d9c7d5f4b8380a9a9c11f3 --- /dev/null +++ b/internal/net/request.go @@ -0,0 +1,525 @@ +package net + +import ( + "bytes" + "context" + "fmt" + "io" + "math" + "net/http" + "strconv" + "strings" + "sync" + "time" + + "github.com/alist-org/alist/v3/pkg/http_range" + "github.com/aws/aws-sdk-go/aws/awsutil" + log "github.com/sirupsen/logrus" +) + +// DefaultDownloadPartSize is the default range of bytes to get at a time when +// using Download(). +const DefaultDownloadPartSize = 1024 * 1024 * 10 + +// DefaultDownloadConcurrency is the default number of goroutines to spin up +// when using Download(). +const DefaultDownloadConcurrency = 2 + +// DefaultPartBodyMaxRetries is the default number of retries to make when a part fails to download. +const DefaultPartBodyMaxRetries = 3 + +type Downloader struct { + PartSize int + + // PartBodyMaxRetries is the number of retry attempts to make for failed part downloads. + PartBodyMaxRetries int + + // The number of goroutines to spin up in parallel when sending parts. + // If this is set to zero, the DefaultDownloadConcurrency value will be used. + // + // Concurrency of 1 will download the parts sequentially. 
+ Concurrency int + + //RequestParam HttpRequestParams + HttpClient HttpRequestFunc +} +type HttpRequestFunc func(ctx context.Context, params *HttpRequestParams) (*http.Response, error) + +func NewDownloader(options ...func(*Downloader)) *Downloader { + d := &Downloader{ + HttpClient: DefaultHttpRequestFunc, + PartSize: DefaultDownloadPartSize, + PartBodyMaxRetries: DefaultPartBodyMaxRetries, + Concurrency: DefaultDownloadConcurrency, + } + for _, option := range options { + option(d) + } + return d +} + +// Download The Downloader makes multi-thread http requests to remote URL, each chunk(except last one) has PartSize, +// cache some data, then return Reader with assembled data +// Supports range, do not support unknown FileSize, and will fail if FileSize is incorrect +// memory usage is at about Concurrency*PartSize, use this wisely +func (d Downloader) Download(ctx context.Context, p *HttpRequestParams) (readCloser io.ReadCloser, err error) { + + var finalP HttpRequestParams + awsutil.Copy(&finalP, p) + if finalP.Range.Length == -1 { + finalP.Range.Length = finalP.Size - finalP.Range.Start + } + impl := downloader{params: &finalP, cfg: d, ctx: ctx} + + // Ensures we don't need nil checks later on + + impl.partBodyMaxRetries = d.PartBodyMaxRetries + + if impl.cfg.Concurrency == 0 { + impl.cfg.Concurrency = DefaultDownloadConcurrency + } + + if impl.cfg.PartSize == 0 { + impl.cfg.PartSize = DefaultDownloadPartSize + } + + return impl.download() +} + +// downloader is the implementation structure used internally by Downloader. +type downloader struct { + ctx context.Context + cancel context.CancelFunc + cfg Downloader + + params *HttpRequestParams //http request params + chunkChannel chan chunk //chunk chanel + + //wg sync.WaitGroup + m sync.Mutex + + nextChunk int //next chunk id + chunks []chunk + bufs []*Buf + //totalBytes int64 + written int64 //total bytes of file downloaded from remote + err error + + partBodyMaxRetries int +} + +// download performs the implementation of the object download across ranged GETs. 
+func (d *downloader) download() (io.ReadCloser, error) { + d.ctx, d.cancel = context.WithCancel(d.ctx) + + pos := d.params.Range.Start + maxPos := d.params.Range.Start + d.params.Range.Length + id := 0 + for pos < maxPos { + finalSize := int64(d.cfg.PartSize) + //check boundary + if pos+finalSize > maxPos { + finalSize = maxPos - pos + } + c := chunk{start: pos, size: finalSize, id: id} + d.chunks = append(d.chunks, c) + pos += finalSize + id++ + } + if len(d.chunks) < d.cfg.Concurrency { + d.cfg.Concurrency = len(d.chunks) + } + + if d.cfg.Concurrency == 1 { + resp, err := d.cfg.HttpClient(d.ctx, d.params) + if err != nil { + return nil, err + } + return resp.Body, nil + } + + // workers + d.chunkChannel = make(chan chunk, d.cfg.Concurrency) + + for i := 0; i < d.cfg.Concurrency; i++ { + buf := NewBuf(d.ctx, d.cfg.PartSize, i) + d.bufs = append(d.bufs, buf) + go d.downloadPart() + } + // initial tasks + for i := 0; i < d.cfg.Concurrency; i++ { + d.sendChunkTask() + } + + var rc io.ReadCloser = NewMultiReadCloser(d.chunks[0].buf, d.interrupt, d.finishBuf) + + // Return error + return rc, d.err +} +func (d *downloader) sendChunkTask() *chunk { + ch := &d.chunks[d.nextChunk] + ch.buf = d.getBuf(d.nextChunk) + ch.buf.Reset(int(ch.size)) + d.chunkChannel <- *ch + d.nextChunk++ + return ch +} + +// when the final reader Close, we interrupt +func (d *downloader) interrupt() error { + d.cancel() + if d.written != d.params.Range.Length { + log.Debugf("Downloader interrupt before finish") + if d.getErr() == nil { + d.setErr(fmt.Errorf("interrupted")) + } + } + defer func() { + close(d.chunkChannel) + for _, buf := range d.bufs { + buf.Close() + } + }() + return d.err +} +func (d *downloader) getBuf(id int) (b *Buf) { + + return d.bufs[id%d.cfg.Concurrency] +} +func (d *downloader) finishBuf(id int) (isLast bool, buf *Buf) { + if id >= len(d.chunks)-1 { + return true, nil + } + if d.nextChunk > id+1 { + return false, d.getBuf(id + 1) + } + ch := d.sendChunkTask() + return false, ch.buf +} + +// downloadPart is an individual goroutine worker reading from the ch channel +// and performing Http request on the data with a given byte range. +func (d *downloader) downloadPart() { + //defer d.wg.Done() + for { + c, ok := <-d.chunkChannel + if !ok { + break + } + if d.getErr() != nil { + // Drain the channel if there is an error, to prevent deadlocking + // of download producer. + continue + } + log.Debugf("downloadPart tried to get chunk") + if err := d.downloadChunk(&c); err != nil { + d.setErr(err) + } + } +} + +// downloadChunk downloads the chunk +func (d *downloader) downloadChunk(ch *chunk) error { + log.Debugf("start new chunk %+v buffer_id =%d", ch, ch.id) + var n int64 + var err error + params := d.getParamsFromChunk(ch) + for retry := 0; retry <= d.partBodyMaxRetries; retry++ { + if d.getErr() != nil { + return d.getErr() + } + n, err = d.tryDownloadChunk(params, ch) + if err == nil { + break + } + // Check if the returned error is an errReadingBody. + // If err is errReadingBody this indicates that an error + // occurred while copying the http response body. + // If this occurs we unwrap the err to set the underlying error + // and attempt any remaining retries. 
+ if bodyErr, ok := err.(*errReadingBody); ok { + err = bodyErr.Unwrap() + } else { + return err + } + + //ch.cur = 0 + + log.Debugf("object part body download interrupted %s, err, %v, retrying attempt %d", + params.URL, err, retry) + } + + d.incrWritten(n) + log.Debugf("down_%d downloaded chunk", ch.id) + //ch.buf.buffer.wg1.Wait() + //log.Debugf("down_%d downloaded chunk,wg wait passed", ch.id) + return err +} + +func (d *downloader) tryDownloadChunk(params *HttpRequestParams, ch *chunk) (int64, error) { + + resp, err := d.cfg.HttpClient(d.ctx, params) + if err != nil { + return 0, err + } + defer resp.Body.Close() + //only check file size on the first task + if ch.id == 0 { + err = d.checkTotalBytes(resp) + if err != nil { + return 0, err + } + } + + n, err := io.Copy(ch.buf, resp.Body) + + if err != nil { + return n, &errReadingBody{err: err} + } + if n != ch.size { + err = fmt.Errorf("chunk download size incorrect, expected=%d, got=%d", ch.size, n) + return n, &errReadingBody{err: err} + } + + return n, nil +} +func (d *downloader) getParamsFromChunk(ch *chunk) *HttpRequestParams { + var params HttpRequestParams + awsutil.Copy(¶ms, d.params) + + // Get the getBuf byte range of data + params.Range = http_range.Range{Start: ch.start, Length: ch.size} + return ¶ms +} + +func (d *downloader) checkTotalBytes(resp *http.Response) error { + var err error + var totalBytes int64 = math.MinInt64 + contentRange := resp.Header.Get("Content-Range") + if len(contentRange) == 0 { + // ContentRange is nil when the full file contents is provided, and + // is not chunked. Use ContentLength instead. + if resp.ContentLength > 0 { + totalBytes = resp.ContentLength + } + } else { + parts := strings.Split(contentRange, "/") + + total := int64(-1) + + // Checking for whether a numbered total exists + // If one does not exist, we will assume the total to be -1, undefined, + // and sequentially download each chunk until hitting a 416 error + totalStr := parts[len(parts)-1] + if totalStr != "*" { + total, err = strconv.ParseInt(totalStr, 10, 64) + if err != nil { + err = fmt.Errorf("failed extracting file size") + } + } else { + err = fmt.Errorf("file size unknown") + } + + totalBytes = total + } + if totalBytes != d.params.Size && err == nil { + err = fmt.Errorf("expect file size=%d unmatch remote report size=%d, need refresh cache", d.params.Size, totalBytes) + } + if err != nil { + _ = d.interrupt() + d.setErr(err) + } + return err + +} + +func (d *downloader) incrWritten(n int64) { + d.m.Lock() + defer d.m.Unlock() + + d.written += n +} + +// getErr is a thread-safe getter for the error object +func (d *downloader) getErr() error { + d.m.Lock() + defer d.m.Unlock() + + return d.err +} + +// setErr is a thread-safe setter for the error object +func (d *downloader) setErr(e error) { + d.m.Lock() + defer d.m.Unlock() + + d.err = e +} + +// Chunk represents a single chunk of data to write by the worker routine. +// This structure also implements an io.SectionReader style interface for +// io.WriterAt, effectively making it an io.SectionWriter (which does not +// exist). +type chunk struct { + start int64 + size int64 + buf *Buf + id int + + // Downloader takes range (start,length), but this chunk is requesting equal/sub range of it. 
+ // To convert the writer to reader eventually, we need to write within the boundary + //boundary http_range.Range +} + +func DefaultHttpRequestFunc(ctx context.Context, params *HttpRequestParams) (*http.Response, error) { + header := http_range.ApplyRangeToHttpHeader(params.Range, params.HeaderRef) + + res, err := RequestHttp(ctx, "GET", header, params.URL) + if err != nil { + return nil, err + } + return res, nil +} + +type HttpRequestParams struct { + URL string + //only want data within this range + Range http_range.Range + HeaderRef http.Header + //total file size + Size int64 +} +type errReadingBody struct { + err error +} + +func (e *errReadingBody) Error() string { + return fmt.Sprintf("failed to read part body: %v", e.err) +} + +func (e *errReadingBody) Unwrap() error { + return e.err +} + +type MultiReadCloser struct { + cfg *cfg + closer closerFunc + finish finishBufFUnc +} + +type cfg struct { + rPos int //current reader position, start from 0 + curBuf *Buf +} + +type closerFunc func() error +type finishBufFUnc func(id int) (isLast bool, buf *Buf) + +// NewMultiReadCloser to save memory, we re-use limited Buf, and feed data to Read() +func NewMultiReadCloser(buf *Buf, c closerFunc, fb finishBufFUnc) *MultiReadCloser { + return &MultiReadCloser{closer: c, finish: fb, cfg: &cfg{curBuf: buf}} +} + +func (mr MultiReadCloser) Read(p []byte) (n int, err error) { + if mr.cfg.curBuf == nil { + return 0, io.EOF + } + n, err = mr.cfg.curBuf.Read(p) + //log.Debugf("read_%d read current buffer, n=%d ,err=%+v", mr.cfg.rPos, n, err) + if err == io.EOF { + log.Debugf("read_%d finished current buffer", mr.cfg.rPos) + + isLast, next := mr.finish(mr.cfg.rPos) + if isLast { + return n, io.EOF + } + mr.cfg.curBuf = next + mr.cfg.rPos++ + //current.Close() + return n, nil + } + return n, err +} +func (mr MultiReadCloser) Close() error { + return mr.closer() +} + +type Buf struct { + buffer *bytes.Buffer + size int //expected size + ctx context.Context + off int + rw sync.Mutex + //notify chan struct{} +} + +// NewBuf is a buffer that can have 1 read & 1 write at the same time. 
+// when read is faster write, immediately feed data to read after written +func NewBuf(ctx context.Context, maxSize int, id int) *Buf { + d := make([]byte, 0, maxSize) + return &Buf{ + ctx: ctx, + buffer: bytes.NewBuffer(d), + size: maxSize, + //notify: make(chan struct{}), + } +} +func (br *Buf) Reset(size int) { + br.buffer.Reset() + br.size = size + br.off = 0 +} + +func (br *Buf) Read(p []byte) (n int, err error) { + if err := br.ctx.Err(); err != nil { + return 0, err + } + if len(p) == 0 { + return 0, nil + } + if br.off >= br.size { + return 0, io.EOF + } + br.rw.Lock() + n, err = br.buffer.Read(p) + br.rw.Unlock() + if err == nil { + br.off += n + return n, err + } + if err != io.EOF { + return n, err + } + if n != 0 { + br.off += n + return n, nil + } + // n==0, err==io.EOF + // wait for new write for 200ms + select { + case <-br.ctx.Done(): + return 0, br.ctx.Err() + //case <-br.notify: + // return 0, nil + case <-time.After(time.Millisecond * 200): + return 0, nil + } +} + +func (br *Buf) Write(p []byte) (n int, err error) { + if err := br.ctx.Err(); err != nil { + return 0, err + } + br.rw.Lock() + defer br.rw.Unlock() + n, err = br.buffer.Write(p) + select { + //case br.notify <- struct{}{}: + default: + } + return +} + +func (br *Buf) Close() { + //close(br.notify) +} diff --git a/internal/net/request_test.go b/internal/net/request_test.go new file mode 100644 index 0000000000000000000000000000000000000000..032b7376585fda5c000db103d1369210216180eb --- /dev/null +++ b/internal/net/request_test.go @@ -0,0 +1,178 @@ +package net + +//no http range +// + +import ( + "bytes" + "context" + "fmt" + "io" + "net/http" + "sync" + "testing" + + "github.com/alist-org/alist/v3/pkg/http_range" + "github.com/sirupsen/logrus" + "golang.org/x/exp/slices" +) + +var buf22MB = make([]byte, 1024*1024*22) + +func dummyHttpRequest(data []byte, p http_range.Range) io.ReadCloser { + + end := p.Start + p.Length - 1 + + if end >= int64(len(data)) { + end = int64(len(data)) + } + + bodyBytes := data[p.Start:end] + return io.NopCloser(bytes.NewReader(bodyBytes)) +} + +func TestDownloadOrder(t *testing.T) { + buff := []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} + downloader, invocations, ranges := newDownloadRangeClient(buff) + con, partSize := 3, 3 + d := NewDownloader(func(d *Downloader) { + d.Concurrency = con + d.PartSize = partSize + d.HttpClient = downloader.HttpRequest + }) + + var start, length int64 = 2, 10 + length2 := length + if length2 == -1 { + length2 = int64(len(buff)) - start + } + req := &HttpRequestParams{ + Range: http_range.Range{Start: start, Length: length}, + Size: int64(len(buff)), + } + readCloser, err := d.Download(context.Background(), req) + + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + resultBuf, err := io.ReadAll(readCloser) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + if exp, a := int(length), len(resultBuf); exp != a { + t.Errorf("expect buffer length=%d, got %d", exp, a) + } + chunkSize := int(length)/partSize + 1 + if int(length)%partSize == 0 { + chunkSize-- + } + if e, a := chunkSize, *invocations; e != a { + t.Errorf("expect %v API calls, got %v", e, a) + } + + expectRngs := []string{"2-3", "5-3", "8-3", "11-1"} + for _, rng := range expectRngs { + if !slices.Contains(*ranges, rng) { + t.Errorf("expect range %v, but absent in return", rng) + } + } + if e, a := expectRngs, *ranges; len(e) != len(a) { + t.Errorf("expect %v ranges, got %v", e, a) + } +} +func init() { + Formatter := new(logrus.TextFormatter) + 
Formatter.TimestampFormat = "2006-01-02T15:04:05.999999999" + Formatter.FullTimestamp = true + Formatter.ForceColors = true + logrus.SetFormatter(Formatter) + logrus.SetLevel(logrus.DebugLevel) + logrus.Debugf("Download start") +} + +func TestDownloadSingle(t *testing.T) { + buff := []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} + downloader, invocations, ranges := newDownloadRangeClient(buff) + con, partSize := 1, 3 + d := NewDownloader(func(d *Downloader) { + d.Concurrency = con + d.PartSize = partSize + d.HttpClient = downloader.HttpRequest + }) + + var start, length int64 = 2, 10 + req := &HttpRequestParams{ + Range: http_range.Range{Start: start, Length: length}, + Size: int64(len(buff)), + } + + readCloser, err := d.Download(context.Background(), req) + + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + resultBuf, err := io.ReadAll(readCloser) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + if exp, a := int(length), len(resultBuf); exp != a { + t.Errorf("expect buffer length=%d, got %d", exp, a) + } + if e, a := 1, *invocations; e != a { + t.Errorf("expect %v API calls, got %v", e, a) + } + + expectRngs := []string{"2-10"} + for _, rng := range expectRngs { + if !slices.Contains(*ranges, rng) { + t.Errorf("expect range %v, but absent in return", rng) + } + } + if e, a := expectRngs, *ranges; len(e) != len(a) { + t.Errorf("expect %v ranges, got %v", e, a) + } +} + +type downloadCaptureClient struct { + mockedHttpRequest func(params *HttpRequestParams) (*http.Response, error) + GetObjectInvocations int + + RetrievedRanges []string + + lock sync.Mutex +} + +func (c *downloadCaptureClient) HttpRequest(ctx context.Context, params *HttpRequestParams) (*http.Response, error) { + c.lock.Lock() + defer c.lock.Unlock() + + c.GetObjectInvocations++ + + if ¶ms.Range != nil { + c.RetrievedRanges = append(c.RetrievedRanges, fmt.Sprintf("%d-%d", params.Range.Start, params.Range.Length)) + } + + return c.mockedHttpRequest(params) +} + +func newDownloadRangeClient(data []byte) (*downloadCaptureClient, *int, *[]string) { + capture := &downloadCaptureClient{} + + capture.mockedHttpRequest = func(params *HttpRequestParams) (*http.Response, error) { + start, fin := params.Range.Start, params.Range.Start+params.Range.Length + if params.Range.Length == -1 || fin >= int64(len(data)) { + fin = int64(len(data)) + } + bodyBytes := data[start:fin] + + header := &http.Header{} + header.Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", start, fin-1, len(data))) + return &http.Response{ + Body: io.NopCloser(bytes.NewReader(bodyBytes)), + Header: *header, + ContentLength: int64(len(bodyBytes)), + }, nil + } + + return capture, &capture.GetObjectInvocations, &capture.RetrievedRanges +} diff --git a/internal/net/serve.go b/internal/net/serve.go new file mode 100644 index 0000000000000000000000000000000000000000..a05667807593528127257319128aabbf8dd199dc --- /dev/null +++ b/internal/net/serve.go @@ -0,0 +1,250 @@ +package net + +import ( + "context" + "fmt" + "io" + "mime" + "mime/multipart" + "net/http" + "path/filepath" + "strconv" + "strings" + "sync" + "time" + + "github.com/alist-org/alist/v3/drivers/base" + "github.com/alist-org/alist/v3/internal/conf" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/pkg/http_range" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/pkg/errors" + log "github.com/sirupsen/logrus" +) + +//this file is inspired by GO_SDK net.http.ServeContent + +//type RangeReadCloser struct { +// 
+//	GetReaderForRange RangeReaderFunc
+//}
+
+// ServeHTTP replies to the request using the content produced by the
+// provided RangeReaderFunc. The main benefit of ServeHTTP over io.Copy
+// is that it handles Range requests properly, sets the MIME type, and
+// handles If-Match, If-Unmodified-Since, If-None-Match, If-Modified-Since,
+// and If-Range requests.
+//
+// If the response's Content-Type header is not set, ServeHTTP
+// first tries to deduce the type from name's file extension and,
+// if that fails, falls back to application/octet-stream.
+// The name is otherwise unused; in particular it can be empty and is
+// never sent in the response.
+//
+// If modtime is not the zero time or Unix epoch, ServeHTTP
+// includes it in a Last-Modified header in the response. If the
+// request includes an If-Modified-Since header, ServeHTTP uses
+// modtime to decide whether the content needs to be sent at all.
+//
+// The provided RangeReaderFunc must work: ServeHTTP hands it a range,
+// and the caller returns a reader for exactly that range.
+//
+// If the caller has set w's ETag header formatted per RFC 7232, section 2.3,
+// ServeHTTP uses it to handle requests using If-Match, If-None-Match, or If-Range.
+func ServeHTTP(w http.ResponseWriter, r *http.Request, name string, modTime time.Time, size int64, RangeReaderFunc model.RangeReaderFunc) {
+	setLastModified(w, modTime)
+	done, rangeReq := checkPreconditions(w, r, modTime)
+	if done {
+		return
+	}
+
+	if size < 0 {
+		// too many functions need the file size to work, so
+		// support for unknown file size is not implemented here
+		http.Error(w, "negative content size not supported", http.StatusInternalServerError)
+		return
+	}
+
+	code := http.StatusOK
+
+	// If Content-Type isn't set, use the file's extension to find it, but
+	// if the Content-Type is unset explicitly, do not sniff the type.
+	contentTypes, haveType := w.Header()["Content-Type"]
+	var contentType string
+	if !haveType {
+		contentType = mime.TypeByExtension(filepath.Ext(name))
+		if contentType == "" {
+			// most modern applications can handle the default content type
+			contentType = "application/octet-stream"
+		}
+		w.Header().Set("Content-Type", contentType)
+	} else if len(contentTypes) > 0 {
+		contentType = contentTypes[0]
+	}
+
+	// handle Content-Range header.
+	sendSize := size
+	var sendContent io.ReadCloser
+	ranges, err := http_range.ParseRange(rangeReq, size)
+	switch err {
+	case nil:
+	case http_range.ErrNoOverlap:
+		if size == 0 {
+			// Some clients add a Range header to all requests to
+			// limit the size of the response. If the file is empty,
+			// ignore the range header and respond with a 200 rather
+			// than a 416.
+			ranges = nil
+			break
+		}
+		w.Header().Set("Content-Range", fmt.Sprintf("bytes */%d", size))
+		fallthrough
+	default:
+		http.Error(w, err.Error(), http.StatusRequestedRangeNotSatisfiable)
+		return
+	}
+
+	if sumRangesSize(ranges) > size || size < 0 {
+		// The total number of bytes in all the ranges is larger than the size
+		// of the file, or the file size is unknown; ignore the range request.
+		ranges = nil
+	}
+	switch {
+	case len(ranges) == 0:
+		reader, err := RangeReaderFunc(context.Background(), http_range.Range{Length: -1})
+		if err != nil {
+			http.Error(w, err.Error(), http.StatusInternalServerError)
+			return
+		}
+		sendContent = reader
+	case len(ranges) == 1:
+		// RFC 7233, Section 4.1:
+		// "If a single part is being transferred, the server
+		// generating the 206 response MUST generate a
+		// Content-Range header field, describing what range
+		// of the selected representation is enclosed, and a
+		// payload consisting of the range.
+		// ...
+		// A server MUST NOT generate a multipart response to
+		// a request for a single range, since a client that
+		// does not request multiple parts might not support
+		// multipart responses."
+		ra := ranges[0]
+		sendContent, err = RangeReaderFunc(context.Background(), ra)
+		if err != nil {
+			http.Error(w, err.Error(), http.StatusRequestedRangeNotSatisfiable)
+			return
+		}
+		sendSize = ra.Length
+		code = http.StatusPartialContent
+		w.Header().Set("Content-Range", ra.ContentRange(size))
+	case len(ranges) > 1:
+		sendSize, err = rangesMIMESize(ranges, contentType, size)
+		if err != nil {
+			http.Error(w, err.Error(), http.StatusRequestedRangeNotSatisfiable)
+			return
+		}
+		code = http.StatusPartialContent
+
+		pr, pw := io.Pipe()
+		mw := multipart.NewWriter(pw)
+		w.Header().Set("Content-Type", "multipart/byteranges; boundary="+mw.Boundary())
+		sendContent = pr
+		defer pr.Close() // cause writing goroutine to fail and exit if CopyN doesn't finish.
+		go func() {
+			for _, ra := range ranges {
+				part, err := mw.CreatePart(ra.MimeHeader(contentType, size))
+				if err != nil {
+					pw.CloseWithError(err)
+					return
+				}
+				reader, err := RangeReaderFunc(context.Background(), ra)
+				if err != nil {
+					pw.CloseWithError(err)
+					return
+				}
+				if _, err := io.CopyN(part, reader, ra.Length); err != nil {
+					pw.CloseWithError(err)
+					return
+				}
+				//defer reader.Close()
+			}
+
+			mw.Close()
+			pw.Close()
+		}()
+	}
+
+	w.Header().Set("Accept-Ranges", "bytes")
+	if w.Header().Get("Content-Encoding") == "" {
+		w.Header().Set("Content-Length", strconv.FormatInt(sendSize, 10))
+	}
+
+	w.WriteHeader(code)
+
+	if r.Method != "HEAD" {
+		written, err := io.CopyN(w, sendContent, sendSize)
+		if err != nil {
+			log.Warnf("ServeHttp error. err: %s ", err)
+			if written != sendSize {
+				log.Warnf("possible causes: wrong size, a reader that returned incomplete data, or a connection closed early; written bytes: %d, sendSize: %d", written, sendSize)
+			}
+			http.Error(w, err.Error(), http.StatusInternalServerError)
+		}
+	}
+	//defer sendContent.Close()
+}
+func ProcessHeader(origin, override http.Header) http.Header {
+	result := http.Header{}
+	// client header
+	for h, val := range origin {
+		if utils.SliceContains(conf.SlicesMap[conf.ProxyIgnoreHeaders], strings.ToLower(h)) {
+			continue
+		}
+		result[h] = val
+	}
+	// needed header
+	for h, val := range override {
+		result[h] = val
+	}
+	return result
+}
+
+// RequestHttp deals with the headers properly, then sends the request
+func RequestHttp(ctx context.Context, httpMethod string, headerOverride http.Header, URL string) (*http.Response, error) {
+	req, err := http.NewRequestWithContext(ctx, httpMethod, URL, nil)
+	if err != nil {
+		return nil, err
+	}
+	req.Header = headerOverride
+	res, err := HttpClient().Do(req)
+	if err != nil {
+		return nil, err
+	}
+	// TODO clean header with blocklist or passlist
+	res.Header.Del("set-cookie")
+	if res.StatusCode >= 400 {
+		all, _ := io.ReadAll(res.Body)
+		_ = res.Body.Close()
+		msg := string(all)
+		log.Debugln(msg)
+		return nil, fmt.Errorf("http request [%s] failure, status: %d, response: %s", URL, res.StatusCode, msg)
+	}
+	return res, nil
+}
+
+var once sync.Once
+var httpClient *http.Client
+
+func HttpClient() *http.Client {
+	once.Do(func() {
+		httpClient = base.NewHttpClient()
+		httpClient.CheckRedirect = func(req *http.Request, via []*http.Request) error {
+			if len(via) >= 10 {
+				return errors.New("stopped after 10 redirects")
+			}
+			req.Header.Del("Referer")
+			return nil
+		}
+	})
+	return httpClient
+}
diff --git a/internal/net/util.go b/internal/net/util.go
new file mode 100644
index 0000000000000000000000000000000000000000..4347e2c404df56b6ed9f36121fa173cbe3171afb
--- /dev/null
+++ b/internal/net/util.go
@@ -0,0 +1,339 @@
+package net
+
+import (
+	"fmt"
+	"io"
+	"math"
+	"mime/multipart"
+	"net/http"
+	"net/textproto"
+	"strings"
+	"time"
+
+	"github.com/alist-org/alist/v3/pkg/http_range"
+	log "github.com/sirupsen/logrus"
+)
+
+// scanETag determines if a syntactically valid ETag is present at s. If so,
+// the ETag and remaining text after consuming ETag is returned. Otherwise,
+// it returns "", "".
+func scanETag(s string) (etag string, remain string) {
+	s = textproto.TrimString(s)
+	start := 0
+	if strings.HasPrefix(s, "W/") {
+		start = 2
+	}
+	if len(s[start:]) < 2 || s[start] != '"' {
+		return "", ""
+	}
+	// ETag is either W/"text" or "text".
+	// See RFC 7232 2.3.
+	for i := start + 1; i < len(s); i++ {
+		c := s[i]
+		switch {
+		// Character values allowed in ETags.
+		case c == 0x21 || c >= 0x23 && c <= 0x7E || c >= 0x80:
+		case c == '"':
+			return s[:i+1], s[i+1:]
+		default:
+			return "", ""
+		}
+	}
+	return "", ""
+}
+
+// etagStrongMatch reports whether a and b match using strong ETag comparison.
+// Assumes a and b are valid ETags.
+func etagStrongMatch(a, b string) bool {
+	return a == b && a != "" && a[0] == '"'
+}
+
+// etagWeakMatch reports whether a and b match using weak ETag comparison.
+// Assumes a and b are valid ETags.
+func etagWeakMatch(a, b string) bool {
+	return strings.TrimPrefix(a, "W/") == strings.TrimPrefix(b, "W/")
+}
+
+// condResult is the result of an HTTP request precondition check.
+// See https://tools.ietf.org/html/rfc7232 section 3.
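+//
+// For example, an If-None-Match hit makes checkIfNoneMatch return condFalse,
+// which checkPreconditions turns into a 304 Not Modified for GET/HEAD requests.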
+type condResult int + +const ( + condNone condResult = iota + condTrue + condFalse +) + +func checkIfMatch(w http.ResponseWriter, r *http.Request) condResult { + im := r.Header.Get("If-Match") + if im == "" { + return condNone + } + for { + im = textproto.TrimString(im) + if len(im) == 0 { + break + } + if im[0] == ',' { + im = im[1:] + continue + } + if im[0] == '*' { + return condTrue + } + etag, remain := scanETag(im) + if etag == "" { + break + } + if etagStrongMatch(etag, w.Header().Get("Etag")) { + return condTrue + } + im = remain + } + + return condFalse +} + +func checkIfUnmodifiedSince(r *http.Request, modtime time.Time) condResult { + ius := r.Header.Get("If-Unmodified-Since") + if ius == "" || isZeroTime(modtime) { + return condNone + } + t, err := http.ParseTime(ius) + if err != nil { + return condNone + } + + // The Last-Modified header truncates sub-second precision so + // the modtime needs to be truncated too. + modtime = modtime.Truncate(time.Second) + if ret := modtime.Compare(t); ret <= 0 { + return condTrue + } + return condFalse +} + +func checkIfNoneMatch(w http.ResponseWriter, r *http.Request) condResult { + inm := r.Header.Get("If-None-Match") + if inm == "" { + return condNone + } + buf := inm + for { + buf = textproto.TrimString(buf) + if len(buf) == 0 { + break + } + if buf[0] == ',' { + buf = buf[1:] + continue + } + if buf[0] == '*' { + return condFalse + } + etag, remain := scanETag(buf) + if etag == "" { + break + } + if etagWeakMatch(etag, w.Header().Get("Etag")) { + return condFalse + } + buf = remain + } + return condTrue +} + +func checkIfModifiedSince(r *http.Request, modtime time.Time) condResult { + if r.Method != "GET" && r.Method != "HEAD" { + return condNone + } + ims := r.Header.Get("If-Modified-Since") + if ims == "" || isZeroTime(modtime) { + return condNone + } + t, err := http.ParseTime(ims) + if err != nil { + return condNone + } + // The Last-Modified header truncates sub-second precision so + // the modtime needs to be truncated too. + modtime = modtime.Truncate(time.Second) + if ret := modtime.Compare(t); ret <= 0 { + return condFalse + } + return condTrue +} + +func checkIfRange(w http.ResponseWriter, r *http.Request, modtime time.Time) condResult { + if r.Method != "GET" && r.Method != "HEAD" { + return condNone + } + ir := r.Header.Get("If-Range") + if ir == "" { + return condNone + } + etag, _ := scanETag(ir) + if etag != "" { + if etagStrongMatch(etag, w.Header().Get("Etag")) { + return condTrue + } + return condFalse + } + // The If-Range value is typically the ETag value, but it may also be + // the modtime date. See golang.org/issue/8367. + if modtime.IsZero() { + return condFalse + } + t, err := http.ParseTime(ir) + if err != nil { + return condFalse + } + if t.Unix() == modtime.Unix() { + return condTrue + } + return condFalse +} + +var unixEpochTime = time.Unix(0, 0) + +// isZeroTime reports whether t is obviously unspecified (either zero or Unix()=0). 
+func isZeroTime(t time.Time) bool {
+	return t.IsZero() || t.Equal(unixEpochTime)
+}
+
+func setLastModified(w http.ResponseWriter, modtime time.Time) {
+	if !isZeroTime(modtime) {
+		w.Header().Set("Last-Modified", modtime.UTC().Format(http.TimeFormat))
+	}
+}
+
+func writeNotModified(w http.ResponseWriter) {
+	// RFC 7232 section 4.1:
+	// a sender SHOULD NOT generate representation metadata other than the
+	// above listed fields unless said metadata exists for the purpose of
+	// guiding cache updates (e.g., Last-Modified might be useful if the
+	// response does not have an ETag field).
+	h := w.Header()
+	delete(h, "Content-Type")
+	delete(h, "Content-Length")
+	delete(h, "Content-Encoding")
+	if h.Get("Etag") != "" {
+		delete(h, "Last-Modified")
+	}
+	w.WriteHeader(http.StatusNotModified)
+}
+
+// checkPreconditions evaluates request preconditions and reports whether a precondition
+// resulted in sending StatusNotModified or StatusPreconditionFailed.
+func checkPreconditions(w http.ResponseWriter, r *http.Request, modtime time.Time) (done bool, rangeHeader string) {
+	// This function carefully follows RFC 7232 section 6.
+	ch := checkIfMatch(w, r)
+	if ch == condNone {
+		ch = checkIfUnmodifiedSince(r, modtime)
+	}
+	if ch == condFalse {
+		w.WriteHeader(http.StatusPreconditionFailed)
+		return true, ""
+	}
+	switch checkIfNoneMatch(w, r) {
+	case condFalse:
+		if r.Method == "GET" || r.Method == "HEAD" {
+			writeNotModified(w)
+			return true, ""
+		}
+		w.WriteHeader(http.StatusPreconditionFailed)
+		return true, ""
+	case condNone:
+		if checkIfModifiedSince(r, modtime) == condFalse {
+			writeNotModified(w)
+			return true, ""
+		}
+	}
+
+	rangeHeader = r.Header.Get("Range")
+	if rangeHeader != "" && checkIfRange(w, r, modtime) == condFalse {
+		rangeHeader = ""
+	}
+	return false, rangeHeader
+}
+
+func sumRangesSize(ranges []http_range.Range) (size int64) {
+	for _, ra := range ranges {
+		size += ra.Length
+	}
+	return
+}
+
+// countingWriter counts how many bytes have been written to it.
+type countingWriter int64
+
+func (w *countingWriter) Write(p []byte) (n int, err error) {
+	*w += countingWriter(len(p))
+	return len(p), nil
+}
+
+// rangesMIMESize returns the number of bytes it takes to encode the
+// provided ranges as a multipart response.
+func rangesMIMESize(ranges []http_range.Range, contentType string, contentSize int64) (encSize int64, err error) {
+	var w countingWriter
+	mw := multipart.NewWriter(&w)
+	for _, ra := range ranges {
+		_, err := mw.CreatePart(ra.MimeHeader(contentType, contentSize))
+		if err != nil {
+			return 0, err
+		}
+		encSize += ra.Length
+	}
+	err = mw.Close()
+	if err != nil {
+		return 0, err
+	}
+	encSize += int64(w)
+	return encSize, nil
+}
+
+// LimitedReadCloser wraps a io.ReadCloser and limits the number of bytes that can be read from it.
+type LimitedReadCloser struct {
+	rc        io.ReadCloser
+	remaining int
+}
+
+func (l *LimitedReadCloser) Read(buf []byte) (int, error) {
+	if l.remaining <= 0 {
+		return 0, io.EOF
+	}
+
+	if len(buf) > l.remaining {
+		buf = buf[0:l.remaining]
+	}
+
+	n, err := l.rc.Read(buf)
+	l.remaining -= n
+
+	return n, err
+}
+
+func (l *LimitedReadCloser) Close() error {
+	return l.rc.Close()
+}
+
+// GetRangedHttpReader works around HTTP servers that don't support the "Range"
+// header: it reads the whole body, discards the first offset bytes, and then
+// returns a ReadCloser limited to the next length bytes.
+func GetRangedHttpReader(readCloser io.ReadCloser, offset, length int64) (io.ReadCloser, error) {
+	var length_int int
+	if length > math.MaxInt {
+		return nil, fmt.Errorf("does not support length larger than math.MaxInt")
+	}
+	length_int = int(length)
+
+	if offset > 100*1024*1024 {
+		log.Warnf("offset is more than 100MB, if loading data from internet, high-latency and wasting of bandwidth is expected")
+	}
+
+	if _, err := io.Copy(io.Discard, io.LimitReader(readCloser, offset)); err != nil {
+		return nil, err
+	}
+
+	// return an io.ReadCloser that is limited to `length` bytes.
+	return &LimitedReadCloser{readCloser, length_int}, nil
+}
diff --git a/internal/offline_download/all.go b/internal/offline_download/all.go
new file mode 100644
index 0000000000000000000000000000000000000000..6d91d5f6eb0d8d9785c2dc491cf65e4568452802
--- /dev/null
+++ b/internal/offline_download/all.go
@@ -0,0 +1,8 @@
+package offline_download
+
+import (
+	_ "github.com/alist-org/alist/v3/internal/offline_download/aria2"
+	_ "github.com/alist-org/alist/v3/internal/offline_download/http"
+	_ "github.com/alist-org/alist/v3/internal/offline_download/qbit"
+	_ "github.com/alist-org/alist/v3/internal/offline_download/storage"
+)
diff --git a/internal/offline_download/aria2/aria2.go b/internal/offline_download/aria2/aria2.go
new file mode 100644
index 0000000000000000000000000000000000000000..d22b32f9d556b6c2f42b78002e7a216f00f0ebdc
--- /dev/null
+++ b/internal/offline_download/aria2/aria2.go
@@ -0,0 +1,127 @@
+package aria2
+
+import (
+	"context"
+	"fmt"
+	"strconv"
+	"time"
+
+	"github.com/alist-org/alist/v3/internal/errs"
+
+	"github.com/alist-org/alist/v3/internal/conf"
+	"github.com/alist-org/alist/v3/internal/model"
+	"github.com/alist-org/alist/v3/internal/offline_download/tool"
+	"github.com/alist-org/alist/v3/internal/setting"
+	"github.com/alist-org/alist/v3/pkg/aria2/rpc"
+	"github.com/pkg/errors"
+	log "github.com/sirupsen/logrus"
+)
+
+var notify = NewNotify()
+
+type Aria2 struct {
+	client rpc.Client
+}
+
+func (a *Aria2) Run(task *tool.DownloadTask) error {
+	return errs.NotSupport
+}
+
+func (a *Aria2) Name() string {
+	return "aria2"
+}
+
+func (a *Aria2) Items() []model.SettingItem {
+	// aria2 settings
+	return []model.SettingItem{
+		{Key: conf.Aria2Uri, Value: "http://localhost:6800/jsonrpc", Type: conf.TypeString, Group: model.OFFLINE_DOWNLOAD, Flag: model.PRIVATE},
+		{Key: conf.Aria2Secret, Value: "", Type: conf.TypeString, Group: model.OFFLINE_DOWNLOAD, Flag: model.PRIVATE},
+	}
+}
+
+func (a *Aria2) Init() (string, error) {
+	a.client = nil
+	uri := setting.GetStr(conf.Aria2Uri)
+	secret := setting.GetStr(conf.Aria2Secret)
+	c, err := rpc.New(context.Background(), uri, secret, 4*time.Second, notify)
+	if err != nil {
+		return "", errors.Wrap(err, "failed to init aria2 client")
+	}
+	version, err := c.GetVersion()
+	if err != nil {
+		return "", errors.Wrapf(err, "failed to get aria2 version")
+	}
+	a.client = c
+	log.Infof("using aria2 version: %s", version.Version)
+	return fmt.Sprintf("aria2 version: %s", version.Version), nil
+}
+
+func (a *Aria2) IsReady() bool {
+	return a.client != nil
+}
+
+func (a *Aria2) AddURL(args *tool.AddUrlArgs) (string, error) {
+	options := map[string]interface{}{
+		"dir": args.TempDir,
+	}
+	gid, err := a.client.AddURI([]string{args.Url}, options)
+	if err != nil {
+		return "", err
+	}
+	notify.Signals.Store(gid, args.Signal)
+	return gid, nil
+}
+
+func (a *Aria2) Remove(task *tool.DownloadTask) error {
+	_, err := a.client.Remove(task.GID)
+	return err
+}
+
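+// Illustrative call sequence (a sketch only; u, dir and sig are placeholder
+// names, the real wiring lives in tool/download.go):
+//
+//	sig := make(chan int)
+//	gid, _ := a.AddURL(&tool.AddUrlArgs{Url: u, TempDir: dir, Signal: sig})
+//	st, _ := a.Status(&tool.DownloadTask{GID: gid})
+//	// st.Completed reports completion; st.NewGID tracks a followed download
+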
+func (a *Aria2) Status(task *tool.DownloadTask) (*tool.Status, error) { + info, err := a.client.TellStatus(task.GID) + if err != nil { + return nil, err + } + total, err := strconv.ParseUint(info.TotalLength, 10, 64) + if err != nil { + total = 0 + } + downloaded, err := strconv.ParseUint(info.CompletedLength, 10, 64) + if err != nil { + downloaded = 0 + } + s := &tool.Status{ + Completed: info.Status == "complete", + Err: err, + } + s.Progress = float64(downloaded) / float64(total) * 100 + if len(info.FollowedBy) != 0 { + s.NewGID = info.FollowedBy[0] + notify.Signals.Delete(task.GID) + notify.Signals.Store(s.NewGID, task.Signal) + } + switch info.Status { + case "complete": + s.Completed = true + case "error": + s.Err = errors.Errorf("failed to download %s, error: %s", task.GID, info.ErrorMessage) + case "active": + s.Status = "aria2: " + info.Status + if info.Seeder == "true" { + s.Completed = true + } + case "waiting", "paused": + s.Status = "aria2: " + info.Status + case "removed": + s.Err = errors.Errorf("failed to download %s, removed", task.GID) + default: + return nil, errors.Errorf("[aria2] unknown status %s", info.Status) + } + return s, nil +} + +var _ tool.Tool = (*Aria2)(nil) + +func init() { + tool.Tools.Add(&Aria2{}) +} diff --git a/internal/offline_download/aria2/notify.go b/internal/offline_download/aria2/notify.go new file mode 100644 index 0000000000000000000000000000000000000000..056fe5147b464694228e132926970b3610b6681c --- /dev/null +++ b/internal/offline_download/aria2/notify.go @@ -0,0 +1,70 @@ +package aria2 + +import ( + "github.com/alist-org/alist/v3/pkg/aria2/rpc" + "github.com/alist-org/alist/v3/pkg/generic_sync" +) + +const ( + Downloading = iota + Paused + Stopped + Completed + Errored +) + +type Notify struct { + Signals generic_sync.MapOf[string, chan int] +} + +func NewNotify() *Notify { + return &Notify{Signals: generic_sync.MapOf[string, chan int]{}} +} + +func (n *Notify) OnDownloadStart(events []rpc.Event) { + for _, e := range events { + if signal, ok := n.Signals.Load(e.Gid); ok { + signal <- Downloading + } + } +} + +func (n *Notify) OnDownloadPause(events []rpc.Event) { + for _, e := range events { + if signal, ok := n.Signals.Load(e.Gid); ok { + signal <- Paused + } + } +} + +func (n *Notify) OnDownloadStop(events []rpc.Event) { + for _, e := range events { + if signal, ok := n.Signals.Load(e.Gid); ok { + signal <- Stopped + } + } +} + +func (n *Notify) OnDownloadComplete(events []rpc.Event) { + for _, e := range events { + if signal, ok := n.Signals.Load(e.Gid); ok { + signal <- Completed + } + } +} + +func (n *Notify) OnDownloadError(events []rpc.Event) { + for _, e := range events { + if signal, ok := n.Signals.Load(e.Gid); ok { + signal <- Errored + } + } +} + +func (n *Notify) OnBtDownloadComplete(events []rpc.Event) { + for _, e := range events { + if signal, ok := n.Signals.Load(e.Gid); ok { + signal <- Completed + } + } +} diff --git a/internal/offline_download/http/client.go b/internal/offline_download/http/client.go new file mode 100644 index 0000000000000000000000000000000000000000..0db05f35c15a98632875551fd30281d76241c721 --- /dev/null +++ b/internal/offline_download/http/client.go @@ -0,0 +1,85 @@ +package http + +import ( + "fmt" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/internal/offline_download/tool" + "github.com/alist-org/alist/v3/pkg/utils" + "net/http" + "net/url" + "os" + "path" + "path/filepath" +) + +type SimpleHttp struct { + client http.Client +} + +func (s SimpleHttp) Name() string 
{ + return "SimpleHttp" +} + +func (s SimpleHttp) Items() []model.SettingItem { + return nil +} + +func (s SimpleHttp) Init() (string, error) { + return "ok", nil +} + +func (s SimpleHttp) IsReady() bool { + return true +} + +func (s SimpleHttp) AddURL(args *tool.AddUrlArgs) (string, error) { + panic("should not be called") +} + +func (s SimpleHttp) Remove(task *tool.DownloadTask) error { + panic("should not be called") +} + +func (s SimpleHttp) Status(task *tool.DownloadTask) (*tool.Status, error) { + panic("should not be called") +} + +func (s SimpleHttp) Run(task *tool.DownloadTask) error { + u := task.Url + // parse url + _u, err := url.Parse(u) + if err != nil { + return err + } + req, err := http.NewRequestWithContext(task.Ctx(), http.MethodGet, u, nil) + if err != nil { + return err + } + resp, err := s.client.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + if resp.StatusCode >= 400 { + return fmt.Errorf("http status code %d", resp.StatusCode) + } + filename := path.Base(_u.Path) + if n, err := parseFilenameFromContentDisposition(resp.Header.Get("Content-Disposition")); err == nil { + filename = n + } + // save to temp dir + _ = os.MkdirAll(task.TempDir, os.ModePerm) + filePath := filepath.Join(task.TempDir, filename) + file, err := os.Create(filePath) + if err != nil { + return err + } + defer file.Close() + fileSize := resp.ContentLength + err = utils.CopyWithCtx(task.Ctx(), file, resp.Body, fileSize, task.SetProgress) + return err +} + +func init() { + tool.Tools.Add(&SimpleHttp{}) +} diff --git a/internal/offline_download/http/util.go b/internal/offline_download/http/util.go new file mode 100644 index 0000000000000000000000000000000000000000..eefefec24ac154f4619f39aaa40193ed18d53946 --- /dev/null +++ b/internal/offline_download/http/util.go @@ -0,0 +1,21 @@ +package http + +import ( + "fmt" + "mime" +) + +func parseFilenameFromContentDisposition(contentDisposition string) (string, error) { + if contentDisposition == "" { + return "", fmt.Errorf("Content-Disposition is empty") + } + _, params, err := mime.ParseMediaType(contentDisposition) + if err != nil { + return "", err + } + filename := params["filename"] + if filename == "" { + return "", fmt.Errorf("filename not found in Content-Disposition: [%s]", contentDisposition) + } + return filename, nil +} diff --git a/internal/offline_download/qbit/qbit.go b/internal/offline_download/qbit/qbit.go new file mode 100644 index 0000000000000000000000000000000000000000..807ebfef2dc8811449afd6272aa1b14cf574227d --- /dev/null +++ b/internal/offline_download/qbit/qbit.go @@ -0,0 +1,85 @@ +package qbit + +import ( + "github.com/alist-org/alist/v3/internal/conf" + "github.com/alist-org/alist/v3/internal/errs" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/internal/offline_download/tool" + "github.com/alist-org/alist/v3/internal/setting" + "github.com/alist-org/alist/v3/pkg/qbittorrent" + "github.com/pkg/errors" +) + +type QBittorrent struct { + client qbittorrent.Client +} + +func (a *QBittorrent) Run(task *tool.DownloadTask) error { + return errs.NotSupport +} + +func (a *QBittorrent) Name() string { + return "qBittorrent" +} + +func (a *QBittorrent) Items() []model.SettingItem { + // qBittorrent settings + return []model.SettingItem{ + {Key: conf.QbittorrentUrl, Value: "http://admin:adminadmin@localhost:8080/", Type: conf.TypeString, Group: model.OFFLINE_DOWNLOAD, Flag: model.PRIVATE}, + {Key: conf.QbittorrentSeedtime, Value: "0", Type: conf.TypeNumber, Group: 
model.OFFLINE_DOWNLOAD, Flag: model.PRIVATE}, + } +} + +func (a *QBittorrent) Init() (string, error) { + a.client = nil + url := setting.GetStr(conf.QbittorrentUrl) + qbClient, err := qbittorrent.New(url) + if err != nil { + return "", err + } + a.client = qbClient + return "ok", nil +} + +func (a *QBittorrent) IsReady() bool { + return a.client != nil +} + +func (a *QBittorrent) AddURL(args *tool.AddUrlArgs) (string, error) { + err := a.client.AddFromLink(args.Url, args.TempDir, args.UID) + if err != nil { + return "", err + } + return args.UID, nil +} + +func (a *QBittorrent) Remove(task *tool.DownloadTask) error { + err := a.client.Delete(task.GID, false) + return err +} + +func (a *QBittorrent) Status(task *tool.DownloadTask) (*tool.Status, error) { + info, err := a.client.GetInfo(task.GID) + if err != nil { + return nil, err + } + s := &tool.Status{} + s.Progress = float64(info.Completed) / float64(info.Size) * 100 + switch info.State { + case qbittorrent.UPLOADING, qbittorrent.PAUSEDUP, qbittorrent.QUEUEDUP, qbittorrent.STALLEDUP, qbittorrent.FORCEDUP, qbittorrent.CHECKINGUP: + s.Completed = true + case qbittorrent.ALLOCATING, qbittorrent.DOWNLOADING, qbittorrent.METADL, qbittorrent.PAUSEDDL, qbittorrent.QUEUEDDL, qbittorrent.STALLEDDL, qbittorrent.CHECKINGDL, qbittorrent.FORCEDDL, qbittorrent.CHECKINGRESUMEDATA, qbittorrent.MOVING: + s.Status = "[qBittorrent] downloading" + case qbittorrent.ERROR, qbittorrent.MISSINGFILES, qbittorrent.UNKNOWN: + s.Err = errors.Errorf("[qBittorrent] failed to download %s, error: %s", task.GID, info.State) + default: + s.Err = errors.Errorf("[qBittorrent] unknown error occurred downloading %s", task.GID) + } + return s, nil +} + +var _ tool.Tool = (*QBittorrent)(nil) + +func init() { + tool.Tools.Add(&QBittorrent{}) +} diff --git a/internal/offline_download/storage/storage.go b/internal/offline_download/storage/storage.go new file mode 100644 index 0000000000000000000000000000000000000000..4692a217535e0576eb9aefc4e69a27c0e6dfa1c7 --- /dev/null +++ b/internal/offline_download/storage/storage.go @@ -0,0 +1,51 @@ +package storage + +import ( + "github.com/alist-org/alist/v3/internal/errs" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/internal/offline_download/tool" + "github.com/pkg/errors" +) + +type Storage struct { +} + +func (a *Storage) Run(task *tool.DownloadTask) error { + return errs.NotSupport +} + +func (a *Storage) Name() string { + return "storage" +} + +func (a *Storage) Items() []model.SettingItem { + // qBittorrent settings + return []model.SettingItem{} +} + +func (a *Storage) Init() (string, error) { + return "ok", nil +} + +func (a *Storage) IsReady() bool { + return true +} + +func (a *Storage) AddURL(args *tool.AddUrlArgs) (string, error) { + return "ok", nil +} + +func (a *Storage) Remove(task *tool.DownloadTask) error { + return errors.Errorf("Failed to Remove") +} + +func (a *Storage) Status(task *tool.DownloadTask) (*tool.Status, error) { + s := &tool.Status{} + return s, nil +} + +var _ tool.Tool = (*Storage)(nil) + +func init() { + tool.Tools.Add(&Storage{}) +} diff --git a/internal/offline_download/tool/add.go b/internal/offline_download/tool/add.go new file mode 100644 index 0000000000000000000000000000000000000000..d9530036a9616e9e67c173549ea6ba82b498abc6 --- /dev/null +++ b/internal/offline_download/tool/add.go @@ -0,0 +1,92 @@ +package tool + +import ( + "context" + "path/filepath" + + "github.com/alist-org/alist/v3/internal/conf" + "github.com/alist-org/alist/v3/internal/errs" + 
"github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/internal/op" + "github.com/alist-org/alist/v3/pkg/tache" + "github.com/google/uuid" + "github.com/pkg/errors" +) + +type DeletePolicy string + +const ( + DeleteOnUploadSucceed DeletePolicy = "delete_on_upload_succeed" + DeleteOnUploadFailed DeletePolicy = "delete_on_upload_failed" + DeleteNever DeletePolicy = "delete_never" + DeleteAlways DeletePolicy = "delete_always" +) + +type AddURLArgs struct { + URL string + DstDirPath string + Tool string + DeletePolicy DeletePolicy +} + +func AddURL(ctx context.Context, args *AddURLArgs) (tache.TaskWithInfo, error) { + // get tool + tool, err := Tools.Get(args.Tool) + + if err != nil { + return nil, errors.Wrapf(err, "failed get tool") + } + // check tool is ready + if !tool.IsReady() { + // try to init tool + if _, err := tool.Init(); err != nil { + return nil, errors.Wrapf(err, "failed init tool %s", args.Tool) + } + } + // check storage + storage, dstDirActualPath, err := op.GetStorageAndActualPath(args.DstDirPath) + if err != nil { + return nil, errors.WithMessage(err, "failed get storage") + } + // check is it could upload + if storage.Config().NoUpload { + return nil, errors.WithStack(errs.UploadNotSupported) + } + // check path is valid + obj, err := op.Get(ctx, storage, dstDirActualPath) + if err != nil { + if !errs.IsObjectNotFound(err) { + return nil, errors.WithMessage(err, "failed get object") + } + } else { + if !obj.IsDir() { + // can't add to a file + return nil, errors.WithStack(errs.NotFolder) + } + } + + uid := uuid.NewString() + tempDir := filepath.Join(conf.Conf.TempDir, args.Tool, uid) + t := &DownloadTask{ + Url: args.URL, + DstDirPath: args.DstDirPath, + TempDir: tempDir, + DeletePolicy: args.DeletePolicy, + tool: tool, + } + if tool.Name() == "storage" { + args := model.FsOtherArgs{ + Path: args.DstDirPath, + Method: "offline", + Data: args.URL, + } + _, err := op.Offline(ctx, obj, storage, args) + if err != nil { + return nil, errors.WithMessage(err, "failed add task") + } + } else { + DownloadTaskManager.Add(t) + } + return t, nil + +} diff --git a/internal/offline_download/tool/all_test.go b/internal/offline_download/tool/all_test.go new file mode 100644 index 0000000000000000000000000000000000000000..27da5e32a89be12a64e832f47228a0ecaac6eb65 --- /dev/null +++ b/internal/offline_download/tool/all_test.go @@ -0,0 +1,17 @@ +package tool_test + +import ( + "testing" + + "github.com/alist-org/alist/v3/internal/offline_download/tool" +) + +func TestGetFiles(t *testing.T) { + files, err := tool.GetFiles("..") + if err != nil { + t.Fatal(err) + } + for _, file := range files { + t.Log(file.Name, file.Size, file.Path, file.Modified) + } +} diff --git a/internal/offline_download/tool/base.go b/internal/offline_download/tool/base.go new file mode 100644 index 0000000000000000000000000000000000000000..3b9fb07a999113861d01903687d005fb980c2015 --- /dev/null +++ b/internal/offline_download/tool/base.go @@ -0,0 +1,66 @@ +package tool + +import ( + "io" + "os" + "time" + + "github.com/alist-org/alist/v3/internal/model" +) + +type AddUrlArgs struct { + Url string + UID string + TempDir string + Signal chan int +} + +type Status struct { + Progress float64 + NewGID string + Completed bool + Status string + Err error +} + +type Tool interface { + Name() string + // Items return the setting items the tool need + Items() []model.SettingItem + Init() (string, error) + IsReady() bool + // AddURL add an uri to download, return the task id + AddURL(args *AddUrlArgs) 
(string, error) + // Remove the download if task been canceled + Remove(task *DownloadTask) error + // Status return the status of the download task, if an error occurred, return the error in Status.Err + Status(task *DownloadTask) (*Status, error) + + // Run for simple http download + Run(task *DownloadTask) error +} + +type GetFileser interface { + // GetFiles return the files of the download task, if nil, means walk the temp dir to get the files + GetFiles(task *DownloadTask) []File +} + +type File struct { + // ReadCloser for http client + ReadCloser io.ReadCloser + Name string + Size int64 + Path string + Modified time.Time +} + +func (f *File) GetReadCloser() (io.ReadCloser, error) { + if f.ReadCloser != nil { + return f.ReadCloser, nil + } + file, err := os.Open(f.Path) + if err != nil { + return nil, err + } + return file, nil +} diff --git a/internal/offline_download/tool/download.go b/internal/offline_download/tool/download.go new file mode 100644 index 0000000000000000000000000000000000000000..4731aaeb72db9c630fe95e3cb901fcf6f5f0cef4 --- /dev/null +++ b/internal/offline_download/tool/download.go @@ -0,0 +1,179 @@ +package tool + +import ( + "fmt" + "time" + + "github.com/alist-org/alist/v3/internal/conf" + "github.com/alist-org/alist/v3/internal/errs" + "github.com/alist-org/alist/v3/internal/op" + "github.com/alist-org/alist/v3/internal/setting" + "github.com/alist-org/alist/v3/pkg/tache" + "github.com/pkg/errors" + log "github.com/sirupsen/logrus" +) + +type DownloadTask struct { + tache.Base + Name string `json:"name"` + Url string `json:"url"` + DstDirPath string `json:"dst_dir_path"` + TempDir string `json:"temp_dir"` + DeletePolicy DeletePolicy `json:"delete_policy"` + + Status string `json:"status"` + Signal chan int `json:"-"` + GID string `json:"-"` + tool Tool + callStatusRetried int +} + +func (t *DownloadTask) OnFailed() { + result := fmt.Sprintf("%s下载失败:%s", t.Url, t.GetErr()) + log.Debug(result) + if setting.GetBool(conf.NotifyEnabled) && setting.GetBool(conf.NotifyOnDownloadFailed) { + go op.Notify("文件下载结果", result) + } + +} + +func (t *DownloadTask) OnSucceeded() { + result := fmt.Sprintf("%s下载成功", t.Url) + log.Debug(result) + if setting.GetBool(conf.NotifyEnabled) && setting.GetBool(conf.NotifyOnDownloadSucceeded) { + go op.Notify("文件下载结果", result) + } +} + +func (t *DownloadTask) Run() error { + t.Name = fmt.Sprintf("download %s to (%s)", t.Url, t.DstDirPath) + if err := t.tool.Run(t); !errs.IsNotSupportError(err) { + if err == nil { + return t.Complete() + } + return err + } + t.Signal = make(chan int) + defer func() { + t.Signal = nil + }() + gid, err := t.tool.AddURL(&AddUrlArgs{ + Url: t.Url, + UID: t.ID, + TempDir: t.TempDir, + Signal: t.Signal, + }) + if err != nil { + return err + } + t.GID = gid + var ( + ok bool + ) +outer: + for { + select { + case <-t.CtxDone(): + err := t.tool.Remove(t) + return err + case <-t.Signal: + ok, err = t.Update() + if ok { + break outer + } + case <-time.After(time.Second * 3): + ok, err = t.Update() + if ok { + break outer + } + } + } + if err != nil { + return err + } + t.Status = "offline download completed, maybe transferring" + // hack for qBittorrent + if t.tool.Name() == "qBittorrent" { + seedTime := setting.GetInt(conf.QbittorrentSeedtime, 0) + if seedTime >= 0 { + t.Status = "offline download completed, waiting for seeding" + <-time.After(time.Minute * time.Duration(seedTime)) + err := t.tool.Remove(t) + if err != nil { + log.Errorln(err.Error()) + } + } + } + return nil +} + +// Update download status, return 
true if download completed +func (t *DownloadTask) Update() (bool, error) { + info, err := t.tool.Status(t) + if err != nil { + t.callStatusRetried++ + log.Errorf("failed to get status of %s, retried %d times", t.ID, t.callStatusRetried) + return false, nil + } + if t.callStatusRetried > 5 { + return true, errors.Errorf("failed to get status of %s, retried %d times", t.ID, t.callStatusRetried) + } + t.callStatusRetried = 0 + t.SetProgress(info.Progress) + t.Status = fmt.Sprintf("[%s]: %s", t.tool.Name(), info.Status) + if info.NewGID != "" { + log.Debugf("followen by: %+v", info.NewGID) + t.GID = info.NewGID + return false, nil + } + // if download completed + if info.Completed { + err := t.Complete() + return true, errors.WithMessage(err, "failed to transfer file") + } + // if download failed + if info.Err != nil { + return true, errors.Errorf("failed to download %s, error: %s", t.ID, info.Err.Error()) + } + return false, nil +} + +func (t *DownloadTask) Complete() error { + var ( + files []File + err error + ) + if getFileser, ok := t.tool.(GetFileser); ok { + files = getFileser.GetFiles(t) + } else { + files, err = GetFiles(t.TempDir) + if err != nil { + return errors.Wrapf(err, "failed to get files") + } + } + // upload files + for i, _ := range files { + file := files[i] + TransferTaskManager.Add(&TransferTask{ + file: file, + DstDirPath: t.DstDirPath, + TempDir: t.TempDir, + DeletePolicy: t.DeletePolicy, + FileDir: file.Path, + }) + } + return nil +} + +func (t *DownloadTask) GetName() string { + return t.Name + //return fmt.Sprintf("download %s to (%s)", t.Url, t.DstDirPath) +} + +func (t *DownloadTask) GetStatus() string { + return t.Status +} + +var ( + DownloadTaskManager *tache.Manager[*DownloadTask] +) diff --git a/internal/offline_download/tool/tools.go b/internal/offline_download/tool/tools.go new file mode 100644 index 0000000000000000000000000000000000000000..9de7d526ab00977f3f2410175a8611ab3df4bcf8 --- /dev/null +++ b/internal/offline_download/tool/tools.go @@ -0,0 +1,39 @@ +package tool + +import ( + "fmt" + "github.com/alist-org/alist/v3/internal/model" +) + +var ( + Tools = make(ToolsManager) +) + +type ToolsManager map[string]Tool + +func (t ToolsManager) Get(name string) (Tool, error) { + if tool, ok := t[name]; ok { + return tool, nil + } + return nil, fmt.Errorf("tool %s not found", name) +} + +func (t ToolsManager) Add(tool Tool) { + t[tool.Name()] = tool +} + +func (t ToolsManager) Names() []string { + names := make([]string, 0, len(t)) + for name := range t { + names = append(names, name) + } + return names +} + +func (t ToolsManager) Items() []model.SettingItem { + var items []model.SettingItem + for _, tool := range t { + items = append(items, tool.Items()...) 
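+		// each registered tool contributes its own setting items, so a new
+		// tool's settings surface in the admin panel without extra wiring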
+ } + return items +} diff --git a/internal/offline_download/tool/transfer.go b/internal/offline_download/tool/transfer.go new file mode 100644 index 0000000000000000000000000000000000000000..31a4981628bac4446c191f34e27ab734d741303a --- /dev/null +++ b/internal/offline_download/tool/transfer.go @@ -0,0 +1,93 @@ +package tool + +import ( + "fmt" + "os" + "path/filepath" + + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/internal/op" + "github.com/alist-org/alist/v3/internal/stream" + "github.com/alist-org/alist/v3/pkg/tache" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/pkg/errors" + log "github.com/sirupsen/logrus" +) + +type TransferTask struct { + tache.Base + file File + FileDir string `json:"file_dir"` + DstDirPath string `json:"dst_dir_path"` + TempDir string `json:"temp_dir"` + DeletePolicy DeletePolicy `json:"delete_policy"` +} + +func (t *TransferTask) Run() error { + // check dstDir again + var err error + if (t.file == File{}) { + t.file, err = GetFile(t.FileDir) + if err != nil { + return errors.Wrapf(err, "failed to get file %s", t.FileDir) + } + } + + storage, dstDirActualPath, err := op.GetStorageAndActualPath(t.DstDirPath) + if err != nil { + return errors.WithMessage(err, "failed get storage") + } + mimetype := utils.GetMimeType(t.file.Path) + rc, err := t.file.GetReadCloser() + if err != nil { + return errors.Wrapf(err, "failed to open file %s", t.file.Path) + } + s := &stream.FileStream{ + Ctx: nil, + Obj: &model.Object{ + Name: filepath.Base(t.file.Path), + Size: t.file.Size, + Modified: t.file.Modified, + IsFolder: false, + }, + Reader: rc, + Mimetype: mimetype, + Closers: utils.NewClosers(rc), + } + relDir, err := filepath.Rel(t.TempDir, filepath.Dir(t.file.Path)) + if err != nil { + log.Errorf("find relation directory error: %v", err) + } + newDistDir := filepath.Join(dstDirActualPath, relDir) + return op.Put(t.Ctx(), storage, newDistDir, s, t.SetProgress) +} + +func (t *TransferTask) GetName() string { + return fmt.Sprintf("transfer %s to [%s]", t.file.Path, t.DstDirPath) +} + +func (t *TransferTask) GetStatus() string { + return "transferring" +} + +func (t *TransferTask) OnSucceeded() { + if t.DeletePolicy == DeleteOnUploadSucceed || t.DeletePolicy == DeleteAlways { + err := os.Remove(t.file.Path) + if err != nil { + log.Errorf("failed to delete file %s, error: %s", t.file.Path, err.Error()) + } + } +} + +func (t *TransferTask) OnFailed() { + if t.DeletePolicy == DeleteOnUploadFailed || t.DeletePolicy == DeleteAlways { + err := os.Remove(t.file.Path) + if err != nil { + log.Errorf("failed to delete file %s, error: %s", t.file.Path, err.Error()) + } + } +} + +var ( + TransferTaskManager *tache.Manager[*TransferTask] +) diff --git a/internal/offline_download/tool/util.go b/internal/offline_download/tool/util.go new file mode 100644 index 0000000000000000000000000000000000000000..b2c6ec02bfa1c2d2ecffff4748361c125cf5ec26 --- /dev/null +++ b/internal/offline_download/tool/util.go @@ -0,0 +1,41 @@ +package tool + +import ( + "os" + "path/filepath" +) + +func GetFiles(dir string) ([]File, error) { + var files []File + err := filepath.Walk(dir, func(path string, info os.FileInfo, err error) error { + if err != nil { + return err + } + if !info.IsDir() { + files = append(files, File{ + Name: info.Name(), + Size: info.Size(), + Path: path, + Modified: info.ModTime(), + }) + } + return nil + }) + if err != nil { + return nil, err + } + return files, nil +} + +func GetFile(path string) (File, error) { + info, err := os.Stat(path) + 
if err != nil { + return File{}, err + } + return File{ + Name: info.Name(), + Size: info.Size(), + Path: path, + Modified: info.ModTime(), + }, nil +} diff --git a/internal/op/const.go b/internal/op/const.go new file mode 100644 index 0000000000000000000000000000000000000000..0b4498c8c0ef1fa59ab9ddd362bc4d2744ab734a --- /dev/null +++ b/internal/op/const.go @@ -0,0 +1,7 @@ +package op + +const ( + WORK = "work" + DISABLED = "disabled" + RootName = "root" +) diff --git a/internal/op/driver.go b/internal/op/driver.go new file mode 100644 index 0000000000000000000000000000000000000000..b2cace865ec8ff28b50d7459d3f8cffc4778b5ba --- /dev/null +++ b/internal/op/driver.go @@ -0,0 +1,179 @@ +package op + +import ( + "reflect" + "strings" + + "github.com/alist-org/alist/v3/internal/conf" + "github.com/alist-org/alist/v3/internal/driver" + "github.com/pkg/errors" +) + +type DriverConstructor func() driver.Driver + +var driverMap = map[string]DriverConstructor{} +var driverInfoMap = map[string]driver.Info{} + +func RegisterDriver(driver DriverConstructor) { + // log.Infof("register driver: [%s]", config.Name) + tempDriver := driver() + tempConfig := tempDriver.Config() + registerDriverItems(tempConfig, tempDriver.GetAddition()) + driverMap[tempConfig.Name] = driver +} + +func GetDriver(name string) (DriverConstructor, error) { + n, ok := driverMap[name] + if !ok { + return nil, errors.Errorf("no driver named: %s", name) + } + return n, nil +} + +func GetDriverNames() []string { + var driverNames []string + for k := range driverInfoMap { + driverNames = append(driverNames, k) + } + return driverNames +} + +func GetDriverInfoMap() map[string]driver.Info { + return driverInfoMap +} + +func registerDriverItems(config driver.Config, addition driver.Additional) { + // log.Debugf("addition of %s: %+v", config.Name, addition) + tAddition := reflect.TypeOf(addition) + for tAddition.Kind() == reflect.Pointer { + tAddition = tAddition.Elem() + } + mainItems := getMainItems(config) + additionalItems := getAdditionalItems(tAddition, config.DefaultRoot) + driverInfoMap[config.Name] = driver.Info{ + Common: mainItems, + Additional: additionalItems, + Config: config, + } +} + +func getMainItems(config driver.Config) []driver.Item { + items := []driver.Item{{ + Name: "mount_path", + Type: conf.TypeString, + Required: true, + Help: "The path you want to mount to, it is unique and cannot be repeated", + }, { + Name: "order", + Type: conf.TypeNumber, + Help: "use to sort", + }, { + Name: "remark", + Type: conf.TypeText, + }, { + Name: "group", + Type: conf.TypeString, + Default: "未分组", + }, { + Name: "sync_group", + Type: conf.TypeBool, + Default: "false", + Help: "同时修改同组内所有存储", + }} + if !config.NoCache { + items = append(items, driver.Item{ + Name: "cache_expiration", + Type: conf.TypeNumber, + Default: "30", + Required: true, + Help: "The cache expiration time for this storage", + }) + } + if !config.OnlyProxy && !config.OnlyLocal { + items = append(items, []driver.Item{{ + Name: "web_proxy", + Type: conf.TypeBool, + }, { + Name: "webdav_policy", + Type: conf.TypeSelect, + Options: "302_redirect,use_proxy_url,native_proxy", + Default: "302_redirect", + Required: true, + }, + }...) 
+ } else { + items = append(items, driver.Item{ + Name: "webdav_policy", + Type: conf.TypeSelect, + Default: "native_proxy", + Options: "use_proxy_url,native_proxy", + Required: true, + }) + } + items = append(items, driver.Item{ + Name: "down_proxy_url", + Type: conf.TypeText, + }) + if config.LocalSort { + items = append(items, []driver.Item{{ + Name: "order_by", + Type: conf.TypeSelect, + Options: "name,size,modified", + }, { + Name: "order_direction", + Type: conf.TypeSelect, + Options: "asc,desc", + }}...) + } + items = append(items, driver.Item{ + Name: "extract_folder", + Type: conf.TypeSelect, + Options: "front,back", + }) + items = append(items, driver.Item{ + Name: "enable_sign", + Type: conf.TypeBool, + Default: "false", + Required: true, + }) + return items +} +func getAdditionalItems(t reflect.Type, defaultRoot string) []driver.Item { + var items []driver.Item + for i := 0; i < t.NumField(); i++ { + field := t.Field(i) + if field.Type.Kind() == reflect.Struct { + items = append(items, getAdditionalItems(field.Type, defaultRoot)...) + continue + } + tag := field.Tag + ignore, ok1 := tag.Lookup("ignore") + name, ok2 := tag.Lookup("json") + if (ok1 && ignore == "true") || !ok2 { + continue + } + item := driver.Item{ + Name: name, + Type: strings.ToLower(field.Type.Name()), + Default: tag.Get("default"), + Options: tag.Get("options"), + Required: tag.Get("required") == "true", + Help: tag.Get("help"), + } + if tag.Get("type") != "" { + item.Type = tag.Get("type") + } + if item.Name == "root_folder_id" || item.Name == "root_folder_path" { + if item.Default == "" { + item.Default = defaultRoot + } + item.Required = item.Default != "" + } + // set default type to string + if item.Type == "" { + item.Type = "string" + } + items = append(items, item) + } + return items +} diff --git a/internal/op/driver_test.go b/internal/op/driver_test.go new file mode 100644 index 0000000000000000000000000000000000000000..1c07c9f859b103d28cc527cc9bb4e1919b26ce2d --- /dev/null +++ b/internal/op/driver_test.go @@ -0,0 +1,17 @@ +package op_test + +import ( + "testing" + + _ "github.com/alist-org/alist/v3/drivers" + "github.com/alist-org/alist/v3/internal/op" +) + +func TestDriverItemsMap(t *testing.T) { + itemsMap := op.GetDriverInfoMap() + if len(itemsMap) != 0 { + t.Logf("driverInfoMap: %v", itemsMap) + } else { + t.Errorf("expected driverInfoMap not empty, but got empty") + } +} diff --git a/internal/op/fs.go b/internal/op/fs.go new file mode 100644 index 0000000000000000000000000000000000000000..b815237ca506091da88a802ddbb65a1c0618cd8b --- /dev/null +++ b/internal/op/fs.go @@ -0,0 +1,595 @@ +package op + +import ( + "context" + stdpath "path" + "time" + + "github.com/Xhofe/go-cache" + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/errs" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/pkg/generic_sync" + "github.com/alist-org/alist/v3/pkg/singleflight" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/pkg/errors" + log "github.com/sirupsen/logrus" +) + +// In order to facilitate adding some other things before and after file op + +var listCache = cache.NewMemCache(cache.WithShards[[]model.Obj](64)) +var listG singleflight.Group[[]model.Obj] + +func updateCacheObj(storage driver.Driver, path string, oldObj model.Obj, newObj model.Obj) { + key := Key(storage, path) + objs, ok := listCache.Get(key) + if ok { + for i, obj := range objs { + if obj.GetName() == oldObj.GetName() { + objs[i] = newObj + break + } + } + 
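+		// write the mutated slice back so the cache entry's TTL is refreshed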
listCache.Set(key, objs, cache.WithEx[[]model.Obj](time.Minute*time.Duration(storage.GetStorage().CacheExpiration))) + } +} + +func delCacheObj(storage driver.Driver, path string, obj model.Obj) { + key := Key(storage, path) + objs, ok := listCache.Get(key) + if ok { + for i, oldObj := range objs { + if oldObj.GetName() == obj.GetName() { + objs = append(objs[:i], objs[i+1:]...) + break + } + } + listCache.Set(key, objs, cache.WithEx[[]model.Obj](time.Minute*time.Duration(storage.GetStorage().CacheExpiration))) + } +} + +var addSortDebounceMap generic_sync.MapOf[string, func(func())] + +func addCacheObj(storage driver.Driver, path string, newObj model.Obj) { + key := Key(storage, path) + objs, ok := listCache.Get(key) + if ok { + for i, obj := range objs { + if obj.GetName() == newObj.GetName() { + objs[i] = newObj + return + } + } + + // Simple separation of files and folders + if len(objs) > 0 && objs[len(objs)-1].IsDir() == newObj.IsDir() { + objs = append(objs, newObj) + } else { + objs = append([]model.Obj{newObj}, objs...) + } + + if storage.Config().LocalSort { + debounce, _ := addSortDebounceMap.LoadOrStore(key, utils.NewDebounce(time.Minute)) + log.Debug("addCacheObj: wait start sort") + debounce(func() { + log.Debug("addCacheObj: start sort") + model.SortFiles(objs, storage.GetStorage().OrderBy, storage.GetStorage().OrderDirection) + addSortDebounceMap.Delete(key) + }) + } + + listCache.Set(key, objs, cache.WithEx[[]model.Obj](time.Minute*time.Duration(storage.GetStorage().CacheExpiration))) + } +} + +func ClearCache(storage driver.Driver, path string) { + objs, ok := listCache.Get(Key(storage, path)) + if ok { + for _, obj := range objs { + if obj.IsDir() { + ClearCache(storage, stdpath.Join(path, obj.GetName())) + } + } + } + listCache.Del(Key(storage, path)) +} + +func Key(storage driver.Driver, path string) string { + return stdpath.Join(storage.GetStorage().MountPath, utils.FixAndCleanPath(path)) +} + +// List files in storage, not contains virtual file +func List(ctx context.Context, storage driver.Driver, path string, args model.ListArgs, refresh ...bool) ([]model.Obj, error) { + if storage.Config().CheckStatus && storage.GetStorage().Status != WORK { + return nil, errors.Errorf("storage not init: %s", storage.GetStorage().Status) + } + path = utils.FixAndCleanPath(path) + log.Debugf("op.List %s", path) + key := Key(storage, path) + if !utils.IsBool(refresh...) 
{
+		if files, ok := listCache.Get(key); ok {
+			log.Debugf("use cache when list %s", path)
+			return files, nil
+		}
+	}
+	dir, err := GetUnwrap(ctx, storage, path)
+	if err != nil {
+		return nil, errors.WithMessage(err, "failed get dir")
+	}
+	log.Debugf("list dir: %+v", dir)
+	if !dir.IsDir() {
+		return nil, errors.WithStack(errs.NotFolder)
+	}
+	objs, err, _ := listG.Do(key, func() ([]model.Obj, error) {
+		files, err := storage.List(ctx, dir, args)
+		if err != nil {
+			return nil, errors.Wrapf(err, "failed to list objs")
+		}
+		// set path
+		for _, f := range files {
+			if s, ok := f.(model.SetPath); ok && f.GetPath() == "" && dir.GetPath() != "" {
+				s.SetPath(stdpath.Join(dir.GetPath(), f.GetName()))
+			}
+		}
+		// wrap obj name
+		model.WrapObjsName(files)
+		// call hooks
+		go func(reqPath string, files []model.Obj) {
+			for _, hook := range objsUpdateHooks {
+				hook(reqPath, files)
+			}
+		}(utils.GetFullPath(storage.GetStorage().MountPath, path), files)
+
+		// sort objs
+		if storage.Config().LocalSort {
+			model.SortFiles(files, storage.GetStorage().OrderBy, storage.GetStorage().OrderDirection)
+		}
+		model.ExtractFolder(files, storage.GetStorage().ExtractFolder)
+
+		if !storage.Config().NoCache {
+			if len(files) > 0 {
+				log.Debugf("set cache: %s => %+v", key, files)
+				listCache.Set(key, files, cache.WithEx[[]model.Obj](time.Minute*time.Duration(storage.GetStorage().CacheExpiration)))
+			} else {
+				log.Debugf("del cache: %s", key)
+				listCache.Del(key)
+			}
+		}
+		return files, nil
+	})
+	return objs, err
+}
+
+// Get object from list of files
+func Get(ctx context.Context, storage driver.Driver, path string) (model.Obj, error) {
+	path = utils.FixAndCleanPath(path)
+	log.Debugf("op.Get %s", path)
+
+	// get the obj directly without list so that we can reduce the io
+	if g, ok := storage.(driver.Getter); ok {
+		obj, err := g.Get(ctx, path)
+		if err == nil {
+			return model.WrapObjName(obj), nil
+		}
+	}
+
+	// is root folder
+	if utils.PathEqual(path, "/") {
+		var rootObj model.Obj
+		if getRooter, ok := storage.(driver.GetRooter); ok {
+			obj, err := getRooter.GetRoot(ctx)
+			if err != nil {
+				return nil, errors.WithMessage(err, "failed get root obj")
+			}
+			rootObj = obj
+		} else {
+			switch r := storage.GetAddition().(type) {
+			case driver.IRootId:
+				rootObj = &model.Object{
+					ID:       r.GetRootId(),
+					Name:     RootName,
+					Size:     0,
+					Modified: storage.GetStorage().Modified,
+					IsFolder: true,
+				}
+			case driver.IRootPath:
+				rootObj = &model.Object{
+					Path:     r.GetRootPath(),
+					Name:     RootName,
+					Size:     0,
+					Modified: storage.GetStorage().Modified,
+					IsFolder: true,
+				}
+			default:
+				return nil, errors.Errorf("please implement IRootPath or IRootId or GetRooter method")
+			}
+		}
+		if rootObj == nil {
+			return nil, errors.Errorf("please implement IRootPath or IRootId or GetRooter method")
+		}
+		return &model.ObjWrapName{
+			Name: RootName,
+			Obj:  rootObj,
+		}, nil
+	}
+
+	// not root folder
+	dir, name := stdpath.Split(path)
+	files, err := List(ctx, storage, dir, model.ListArgs{})
+	if err != nil {
+		return nil, errors.WithMessage(err, "failed get parent list")
+	}
+	for _, f := range files {
+		if f.GetName() == name {
+			return f, nil
+		}
+	}
+	log.Debugf("can't find obj with name: %s", name)
+	return nil, errors.WithStack(errs.ObjectNotFound)
+}
+
+func GetUnwrap(ctx context.Context, storage driver.Driver, path string) (model.Obj, error) {
+	obj, err := Get(ctx, storage, path)
+	if err != nil {
+		return nil, err
+	}
+	return model.UnwrapObj(obj), err
+}
+
+var linkCache = cache.NewMemCache(cache.WithShards[*model.Link](16))
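+
+// Link lookups below reuse the list-cache key scheme (mount path joined with
+// the cleaned request path); a link is cached only when the driver reports an
+// Expiration, keys are IP-scoped when link.IPCacheKey is set, and linkG
+// collapses concurrent requests for the same key into a single driver call.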
+var linkG singleflight.Group[*model.Link] + +// Link get link, if is an url. should have an expiry time +func Link(ctx context.Context, storage driver.Driver, path string, args model.LinkArgs) (*model.Link, model.Obj, error) { + if storage.Config().CheckStatus && storage.GetStorage().Status != WORK { + return nil, nil, errors.Errorf("storage not init: %s", storage.GetStorage().Status) + } + file, err := GetUnwrap(ctx, storage, path) + if err != nil { + return nil, nil, errors.WithMessage(err, "failed to get file") + } + if file.IsDir() { + return nil, nil, errors.WithStack(errs.NotFile) + } + key := Key(storage, path) + if link, ok := linkCache.Get(key); ok { + return link, file, nil + } + fn := func() (*model.Link, error) { + link, err := storage.Link(ctx, file, args) + if err != nil { + return nil, errors.Wrapf(err, "failed get link") + } + if link.Expiration != nil { + if link.IPCacheKey { + key = key + ":" + args.IP + } + linkCache.Set(key, link, cache.WithEx[*model.Link](*link.Expiration)) + } + return link, nil + } + link, err, _ := linkG.Do(key, fn) + return link, file, err +} + +// Other api +func Other(ctx context.Context, storage driver.Driver, args model.FsOtherArgs) (interface{}, error) { + obj, err := GetUnwrap(ctx, storage, args.Path) + if err != nil { + return nil, errors.WithMessagef(err, "failed to get obj") + } + if o, ok := storage.(driver.Other); ok { + return o.Other(ctx, model.OtherArgs{ + Obj: obj, + Method: args.Method, + Data: args.Data, + }) + } else { + return nil, errs.NotImplement + } +} + +// Offline api +func Offline(ctx context.Context, obj model.Obj, storage driver.Driver, args model.FsOtherArgs) (interface{}, error) { + if o, ok := storage.(driver.Offline); ok { + return o.Offline(ctx, model.OtherArgs{ + Obj: obj, + Method: args.Method, + Data: args.Data, + }) + } else { + return nil, errs.NotImplement + } + +} + +var mkdirG singleflight.Group[interface{}] + +func MakeDir(ctx context.Context, storage driver.Driver, path string, lazyCache ...bool) error { + if storage.Config().CheckStatus && storage.GetStorage().Status != WORK { + return errors.Errorf("storage not init: %s", storage.GetStorage().Status) + } + path = utils.FixAndCleanPath(path) + key := Key(storage, path) + _, err, _ := mkdirG.Do(key, func() (interface{}, error) { + // check if dir exists + f, err := GetUnwrap(ctx, storage, path) + if err != nil { + if errs.IsObjectNotFound(err) { + parentPath, dirName := stdpath.Split(path) + err = MakeDir(ctx, storage, parentPath) + if err != nil { + return nil, errors.WithMessagef(err, "failed to make parent dir [%s]", parentPath) + } + parentDir, err := GetUnwrap(ctx, storage, parentPath) + // this should not happen + if err != nil { + return nil, errors.WithMessagef(err, "failed to get parent dir [%s]", parentPath) + } + + switch s := storage.(type) { + case driver.MkdirResult: + var newObj model.Obj + newObj, err = s.MakeDir(ctx, parentDir, dirName) + if err == nil { + if newObj != nil { + addCacheObj(storage, parentPath, model.WrapObjName(newObj)) + } else if !utils.IsBool(lazyCache...) { + ClearCache(storage, parentPath) + } + } + case driver.Mkdir: + err = s.MakeDir(ctx, parentDir, dirName) + if err == nil && !utils.IsBool(lazyCache...) 
{ + ClearCache(storage, parentPath) + } + default: + return nil, errs.NotImplement + } + return nil, errors.WithStack(err) + } + return nil, errors.WithMessage(err, "failed to check if dir exists") + } + // dir exists + if f.IsDir() { + return nil, nil + } + // dir to make is a file + return nil, errors.New("file exists") + }) + return err +} + +func Move(ctx context.Context, storage driver.Driver, srcPath, dstDirPath string, lazyCache ...bool) error { + if storage.Config().CheckStatus && storage.GetStorage().Status != WORK { + return errors.Errorf("storage not init: %s", storage.GetStorage().Status) + } + srcPath = utils.FixAndCleanPath(srcPath) + dstDirPath = utils.FixAndCleanPath(dstDirPath) + srcRawObj, err := Get(ctx, storage, srcPath) + if err != nil { + return errors.WithMessage(err, "failed to get src object") + } + srcObj := model.UnwrapObj(srcRawObj) + dstDir, err := GetUnwrap(ctx, storage, dstDirPath) + if err != nil { + return errors.WithMessage(err, "failed to get dst dir") + } + srcDirPath := stdpath.Dir(srcPath) + + switch s := storage.(type) { + case driver.MoveResult: + var newObj model.Obj + newObj, err = s.Move(ctx, srcObj, dstDir) + if err == nil { + delCacheObj(storage, srcDirPath, srcRawObj) + if newObj != nil { + addCacheObj(storage, dstDirPath, model.WrapObjName(newObj)) + } else if !utils.IsBool(lazyCache...) { + ClearCache(storage, dstDirPath) + } + } + case driver.Move: + err = s.Move(ctx, srcObj, dstDir) + if err == nil { + delCacheObj(storage, srcDirPath, srcRawObj) + if !utils.IsBool(lazyCache...) { + ClearCache(storage, dstDirPath) + } + } + default: + return errs.NotImplement + } + return errors.WithStack(err) +} + +func Rename(ctx context.Context, storage driver.Driver, srcPath, dstName string, lazyCache ...bool) error { + if storage.Config().CheckStatus && storage.GetStorage().Status != WORK { + return errors.Errorf("storage not init: %s", storage.GetStorage().Status) + } + srcPath = utils.FixAndCleanPath(srcPath) + srcRawObj, err := Get(ctx, storage, srcPath) + if err != nil { + return errors.WithMessage(err, "failed to get src object") + } + srcObj := model.UnwrapObj(srcRawObj) + srcDirPath := stdpath.Dir(srcPath) + + switch s := storage.(type) { + case driver.RenameResult: + var newObj model.Obj + newObj, err = s.Rename(ctx, srcObj, dstName) + if err == nil { + if newObj != nil { + updateCacheObj(storage, srcDirPath, srcRawObj, model.WrapObjName(newObj)) + } else if !utils.IsBool(lazyCache...) { + ClearCache(storage, srcDirPath) + } + } + case driver.Rename: + err = s.Rename(ctx, srcObj, dstName) + if err == nil && !utils.IsBool(lazyCache...) 
{ + ClearCache(storage, srcDirPath) + } + default: + return errs.NotImplement + } + return errors.WithStack(err) +} + +// Copy Just copy file[s] in a storage +func Copy(ctx context.Context, storage driver.Driver, srcPath, dstDirPath string, lazyCache ...bool) error { + if storage.Config().CheckStatus && storage.GetStorage().Status != WORK { + return errors.Errorf("storage not init: %s", storage.GetStorage().Status) + } + srcPath = utils.FixAndCleanPath(srcPath) + dstDirPath = utils.FixAndCleanPath(dstDirPath) + srcObj, err := GetUnwrap(ctx, storage, srcPath) + if err != nil { + return errors.WithMessage(err, "failed to get src object") + } + dstDir, err := GetUnwrap(ctx, storage, dstDirPath) + if err != nil { + return errors.WithMessage(err, "failed to get dst dir") + } + + switch s := storage.(type) { + case driver.CopyResult: + var newObj model.Obj + newObj, err = s.Copy(ctx, srcObj, dstDir) + if err == nil { + if newObj != nil { + addCacheObj(storage, dstDirPath, model.WrapObjName(newObj)) + } else if !utils.IsBool(lazyCache...) { + ClearCache(storage, dstDirPath) + } + } + case driver.Copy: + err = s.Copy(ctx, srcObj, dstDir) + if err == nil && !utils.IsBool(lazyCache...) { + ClearCache(storage, dstDirPath) + } + default: + return errs.NotImplement + } + return errors.WithStack(err) +} + +func Remove(ctx context.Context, storage driver.Driver, path string) error { + if storage.Config().CheckStatus && storage.GetStorage().Status != WORK { + return errors.Errorf("storage not init: %s", storage.GetStorage().Status) + } + path = utils.FixAndCleanPath(path) + rawObj, err := Get(ctx, storage, path) + if err != nil { + // if object not found, it's ok + if errs.IsObjectNotFound(err) { + log.Debugf("%s have been removed", path) + return nil + } + return errors.WithMessage(err, "failed to get object") + } + dirPath := stdpath.Dir(path) + + switch s := storage.(type) { + case driver.Remove: + err = s.Remove(ctx, model.UnwrapObj(rawObj)) + if err == nil { + delCacheObj(storage, dirPath, rawObj) + // clear folder cache recursively + if rawObj.IsDir() { + ClearCache(storage, path) + } + } + default: + return errs.NotImplement + } + return errors.WithStack(err) +} + +func Put(ctx context.Context, storage driver.Driver, dstDirPath string, file model.FileStreamer, up driver.UpdateProgress, lazyCache ...bool) error { + if storage.Config().CheckStatus && storage.GetStorage().Status != WORK { + return errors.Errorf("storage not init: %s", storage.GetStorage().Status) + } + defer func() { + if err := file.Close(); err != nil { + log.Errorf("failed to close file streamer, %v", err) + } + }() + // if file exist and size = 0, delete it + dstDirPath = utils.FixAndCleanPath(dstDirPath) + dstPath := stdpath.Join(dstDirPath, file.GetName()) + tempName := file.GetName() + ".alist_to_delete" + tempPath := stdpath.Join(dstDirPath, tempName) + fi, err := GetUnwrap(ctx, storage, dstPath) + if err == nil { + if fi.GetSize() == 0 { + err = Remove(ctx, storage, dstPath) + if err != nil { + return errors.WithMessagef(err, "while uploading, failed remove existing file which size = 0") + } + } else if storage.Config().NoOverwriteUpload { + // try to rename old obj + err = Rename(ctx, storage, dstPath, tempName) + if err != nil { + return err + } + } else { + file.SetExist(fi) + } + } + err = MakeDir(ctx, storage, dstDirPath) + if err != nil { + return errors.WithMessagef(err, "failed to make dir [%s]", dstDirPath) + } + parentDir, err := GetUnwrap(ctx, storage, dstDirPath) + // this should not happen + if err != nil { + 
return errors.WithMessagef(err, "failed to get dir [%s]", dstDirPath) + } + // if up is nil, set a default to prevent panic + if up == nil { + up = func(p float64) {} + } + + switch s := storage.(type) { + case driver.PutResult: + var newObj model.Obj + newObj, err = s.Put(ctx, parentDir, file, up) + if err == nil { + if newObj != nil { + addCacheObj(storage, dstDirPath, model.WrapObjName(newObj)) + } else if !utils.IsBool(lazyCache...) { + ClearCache(storage, dstDirPath) + } + } + case driver.Put: + err = s.Put(ctx, parentDir, file, up) + if err == nil && !utils.IsBool(lazyCache...) { + ClearCache(storage, dstDirPath) + } + default: + return errs.NotImplement + } + log.Debugf("put file [%s] done", file.GetName()) + if storage.Config().NoOverwriteUpload && fi != nil && fi.GetSize() > 0 { + if err != nil { + // upload failed, recover old obj + err := Rename(ctx, storage, tempPath, file.GetName()) + if err != nil { + log.Errorf("failed recover old obj: %+v", err) + } + } else { + // upload success, remove old obj + err := Remove(ctx, storage, tempPath) + if err != nil { + return err + } else { + key := Key(storage, stdpath.Join(dstDirPath, file.GetName())) + linkCache.Del(key) + } + } + } + return errors.WithStack(err) +} diff --git a/internal/op/hook.go b/internal/op/hook.go new file mode 100644 index 0000000000000000000000000000000000000000..67c4978b7147c7d2b7aede41fbebfd42fa2331cf --- /dev/null +++ b/internal/op/hook.go @@ -0,0 +1,115 @@ +package op + +import ( + "regexp" + "strings" + + "github.com/alist-org/alist/v3/internal/conf" + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/pkg/errors" + log "github.com/sirupsen/logrus" +) + +// Obj +type ObjsUpdateHook = func(parent string, objs []model.Obj) + +var ( + objsUpdateHooks = make([]ObjsUpdateHook, 0) +) + +func RegisterObjsUpdateHook(hook ObjsUpdateHook) { + objsUpdateHooks = append(objsUpdateHooks, hook) +} + +func HandleObjsUpdateHook(parent string, objs []model.Obj) { + for _, hook := range objsUpdateHooks { + hook(parent, objs) + } +} + +// Setting +type SettingItemHook func(item *model.SettingItem) error + +var settingItemHooks = map[string]SettingItemHook{ + conf.VideoTypes: func(item *model.SettingItem) error { + conf.SlicesMap[conf.VideoTypes] = strings.Split(item.Value, ",") + return nil + }, + conf.AudioTypes: func(item *model.SettingItem) error { + conf.SlicesMap[conf.AudioTypes] = strings.Split(item.Value, ",") + return nil + }, + conf.ImageTypes: func(item *model.SettingItem) error { + conf.SlicesMap[conf.ImageTypes] = strings.Split(item.Value, ",") + return nil + }, + conf.TextTypes: func(item *model.SettingItem) error { + conf.SlicesMap[conf.TextTypes] = strings.Split(item.Value, ",") + return nil + }, + conf.ProxyTypes: func(item *model.SettingItem) error { + conf.SlicesMap[conf.ProxyTypes] = strings.Split(item.Value, ",") + return nil + }, + conf.ProxyIgnoreHeaders: func(item *model.SettingItem) error { + conf.SlicesMap[conf.ProxyIgnoreHeaders] = strings.Split(item.Value, ",") + return nil + }, + conf.PrivacyRegs: func(item *model.SettingItem) error { + regStrs := strings.Split(item.Value, "\n") + regs := make([]*regexp.Regexp, 0, len(regStrs)) + for _, regStr := range regStrs { + reg, err := regexp.Compile(regStr) + if err != nil { + return errors.WithStack(err) + } + regs = append(regs, reg) + } + conf.PrivacyReg = regs + return nil + }, + conf.FilenameCharMapping: func(item *model.SettingItem) error { + err := 
utils.Json.UnmarshalFromString(item.Value, &conf.FilenameCharMap) + if err != nil { + return err + } + log.Debugf("filename char mapping: %+v", conf.FilenameCharMap) + return nil + }, + conf.IgnoreDirectLinkParams: func(item *model.SettingItem) error { + conf.SlicesMap[conf.IgnoreDirectLinkParams] = strings.Split(item.Value, ",") + return nil + }, + conf.StorageGroups: func(item *model.SettingItem) error { + conf.SlicesMap[conf.StorageGroups] = strings.Split(item.Value, ",") + return nil + }, +} + +func RegisterSettingItemHook(key string, hook SettingItemHook) { + settingItemHooks[key] = hook +} + +func HandleSettingItemHook(item *model.SettingItem) (hasHook bool, err error) { + if hook, ok := settingItemHooks[item.Key]; ok { + return true, hook(item) + } + return false, nil +} + +// Storage +type StorageHook func(typ string, storage driver.Driver) + +var storageHooks = make([]StorageHook, 0) + +func callStorageHooks(typ string, storage driver.Driver) { + for _, hook := range storageHooks { + hook(typ, storage) + } +} + +func RegisterStorageHook(hook StorageHook) { + storageHooks = append(storageHooks, hook) +} diff --git a/internal/op/meta.go b/internal/op/meta.go new file mode 100644 index 0000000000000000000000000000000000000000..930f49634c35b43e600c41a3a30be83a006d3a74 --- /dev/null +++ b/internal/op/meta.go @@ -0,0 +1,96 @@ +package op + +import ( + stdpath "path" + "time" + + "github.com/Xhofe/go-cache" + "github.com/alist-org/alist/v3/internal/db" + "github.com/alist-org/alist/v3/internal/errs" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/pkg/singleflight" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/pkg/errors" + "gorm.io/gorm" +) + +var metaCache = cache.NewMemCache(cache.WithShards[*model.Meta](2)) + +// metaG maybe not needed +var metaG singleflight.Group[*model.Meta] + +func GetNearestMeta(path string) (*model.Meta, error) { + return getNearestMeta(utils.FixAndCleanPath(path)) +} +func getNearestMeta(path string) (*model.Meta, error) { + meta, err := GetMetaByPath(path) + if err == nil { + return meta, nil + } + if errors.Cause(err) != errs.MetaNotFound { + return nil, err + } + if path == "/" { + return nil, errs.MetaNotFound + } + return getNearestMeta(stdpath.Dir(path)) +} + +func GetMetaByPath(path string) (*model.Meta, error) { + return getMetaByPath(utils.FixAndCleanPath(path)) +} +func getMetaByPath(path string) (*model.Meta, error) { + meta, ok := metaCache.Get(path) + if ok { + if meta == nil { + return meta, errs.MetaNotFound + } + return meta, nil + } + meta, err, _ := metaG.Do(path, func() (*model.Meta, error) { + _meta, err := db.GetMetaByPath(path) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + metaCache.Set(path, nil) + return nil, errs.MetaNotFound + } + return nil, err + } + metaCache.Set(path, _meta, cache.WithEx[*model.Meta](time.Hour)) + return _meta, nil + }) + return meta, err +} + +func DeleteMetaById(id uint) error { + old, err := db.GetMetaById(id) + if err != nil { + return err + } + metaCache.Del(old.Path) + return db.DeleteMetaById(id) +} + +func UpdateMeta(u *model.Meta) error { + u.Path = utils.FixAndCleanPath(u.Path) + old, err := db.GetMetaById(u.ID) + if err != nil { + return err + } + metaCache.Del(old.Path) + return db.UpdateMeta(u) +} + +func CreateMeta(u *model.Meta) error { + u.Path = utils.FixAndCleanPath(u.Path) + metaCache.Del(u.Path) + return db.CreateMeta(u) +} + +func GetMetaById(id uint) (*model.Meta, error) { + return db.GetMetaById(id) +} + +func 
GetMetas(pageIndex, pageSize int) (metas []model.Meta, count int64, err error) {
+	return db.GetMetas(pageIndex, pageSize)
+}
diff --git a/internal/op/notify.go b/internal/op/notify.go
new file mode 100644
index 0000000000000000000000000000000000000000..ff1ec31e17c57458e71642f3fe6f8fbe2bf9e7b8
--- /dev/null
+++ b/internal/op/notify.go
@@ -0,0 +1,526 @@
+package op
+
+import (
+	"bytes"
+	"encoding/json"
+	"fmt"
+	"mime/multipart"
+	"net/http"
+	"net/url"
+	"reflect"
+	"strings"
+
+	"github.com/alist-org/alist/v3/internal/conf"
+	"github.com/alist-org/alist/v3/internal/model"
+	"github.com/pkg/errors"
+	log "github.com/sirupsen/logrus"
+	"golang.org/x/text/cases"
+	"golang.org/x/text/language"
+)
+
+type SendNotifyPlatform struct{}
+
+// Note: method names must be exported (capitalized), otherwise the
+// reflection-based lookup in Notify cannot find them.
+func (e SendNotifyPlatform) Bark(body string, title string, content string) (bool, error) {
+	var bark model.Bark
+	err := json.Unmarshal([]byte(body), &bark)
+	if err != nil {
+		log.Errorln("failed to parse notify config")
+		return false, errors.Errorf("failed to parse notify config")
+	}
+
+	if len(bark.BarkPush) < 2 {
+		log.Errorln("please set BarkPush correctly")
+		return false, errors.Errorf("please set BarkPush correctly")
+	}
+
+	if !strings.HasPrefix(bark.BarkPush, "http") {
+		bark.BarkPush = fmt.Sprintf("https://api.day.app/%s", bark.BarkPush)
+	}
+	urlValues := url.Values{}
+	urlValues.Set("icon", bark.BarkIcon)
+	urlValues.Set("sound", bark.BarkSound)
+	urlValues.Set("group", bark.BarkGroup)
+	urlValues.Set("level", bark.BarkLevel)
+	urlValues.Set("url", bark.BarkUrl)
+	reqURL := fmt.Sprintf("%s/%s/%s?%s", bark.BarkPush, url.QueryEscape(title), url.QueryEscape(content), urlValues.Encode())
+
+	resp, err := http.Get(reqURL)
+	if err != nil {
+		log.Error("failed to send notification")
+		return false, err
+	}
+	defer resp.Body.Close()
+
+	if resp.StatusCode == http.StatusOK {
+		return true, nil
+	}
+	return false, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
+}
+
+// Note: method name must be exported or the reflection lookup in Notify will fail
+func (e SendNotifyPlatform) Gotify(body string, title string, content string) (bool, error) {
+	m := make(map[string]interface{})
+	err := json.Unmarshal([]byte(body), &m)
+	if err != nil {
+		return false, errors.WithMessage(err, "failed to parse notify config")
+	}
+	gotifyUrl, gotifyToken, gotifyPriority := m["gotifyUrl"].(string), m["gotifyToken"].(string), m["gotifyPriority"].(string)
+
+	surl := fmt.Sprintf("%s/message?token=%s", gotifyUrl, gotifyToken)
+	data := url.Values{}
+	data.Set("title", title)
+	data.Set("message", content)
+	// gotifyPriority is stored as a string, so pass it through as-is
+	data.Set("priority", gotifyPriority)
+
+	req, err := http.NewRequest("POST", surl, bytes.NewBufferString(data.Encode()))
+	if err != nil {
+		return false, err
+	}
+	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+
+	client := &http.Client{}
+	resp, err := client.Do(req)
+	if err != nil {
+		log.Error("failed to send notification")
+		return false, err
+	}
+	defer resp.Body.Close()
+
+	if resp.StatusCode == http.StatusOK {
+		return true, nil
+	}
+	return false, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
+}
+
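+// Each sender below follows the same shape: unmarshal the user's settings
+// JSON into a map, build an HTTP request, and treat any non-200 response as
+// failure. A hypothetical settings payload for Gotify, for illustration only
+// (key names follow the map lookups used above):
+//
+//	{"gotifyUrl": "https://gotify.example.com", "gotifyToken": "AbCdEf", "gotifyPriority": "5"}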
+// Note: method name must be exported or the reflection lookup in Notify will fail
+func (e SendNotifyPlatform) GoCqHttpBot(body string, title string, content string) (bool, error) {
+	m := make(map[string]interface{})
+	err := json.Unmarshal([]byte(body), &m)
+	if err != nil {
+		return false, errors.WithMessage(err, "failed to parse notify config")
+	}
+	goCqHttpBotUrl, goCqHttpBotToken, goCqHttpBotQq := m["goCqHttpBotUrl"].(string), m["goCqHttpBotToken"].(string), m["goCqHttpBotQq"].(string)
+
+	surl := fmt.Sprintf("%s?user_id=%s", goCqHttpBotUrl, goCqHttpBotQq)
+	data := map[string]string{"message": fmt.Sprintf("%s\n%s", title, content)}
+	jsonData, err := json.Marshal(data)
+	if err != nil {
+		return false, err
+	}
+
+	req, err := http.NewRequest("POST", surl, bytes.NewBuffer(jsonData))
+	if err != nil {
+		return false, err
+	}
+	req.Header.Set("Content-Type", "application/json")
+	req.Header.Set("Authorization", "Bearer "+goCqHttpBotToken)
+
+	client := &http.Client{}
+	resp, err := client.Do(req)
+	if err != nil {
+		log.Error("failed to send notification")
+		return false, err
+	}
+	defer resp.Body.Close()
+
+	if resp.StatusCode == http.StatusOK {
+		return true, nil
+	}
+	return false, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
+}
+
+// Note: method name must be exported or the reflection lookup in Notify will fail
+func (e SendNotifyPlatform) ServerChan(body string, title string, content string) (bool, error) {
+	m := make(map[string]interface{})
+	err := json.Unmarshal([]byte(body), &m)
+	if err != nil {
+		return false, errors.WithMessage(err, "failed to parse notify config")
+	}
+	serverChanKey := m["serverChanKey"].(string)
+
+	surl := ""
+	if strings.HasPrefix(serverChanKey, "SCT") {
+		surl = fmt.Sprintf("https://sctapi.ftqq.com/%s.send", serverChanKey)
+	} else {
+		surl = fmt.Sprintf("https://sc.ftqq.com/%s.send", serverChanKey)
+	}
+
+	data := url.Values{}
+	data.Set("title", title)
+	data.Set("desp", content)
+
+	req, err := http.NewRequest("POST", surl, bytes.NewBufferString(data.Encode()))
+	if err != nil {
+		return false, err
+	}
+	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+
+	client := &http.Client{}
+	resp, err := client.Do(req)
+	if err != nil {
+		log.Error("failed to send notification")
+		return false, err
+	}
+	defer resp.Body.Close()
+
+	if resp.StatusCode == http.StatusOK {
+		return true, nil
+	}
+	return false, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
+}
+
+// Note: method name must be exported or the reflection lookup in Notify will fail
+func (e SendNotifyPlatform) PushDeer(body string, title string, content string) (bool, error) {
+	m := make(map[string]interface{})
+	err := json.Unmarshal([]byte(body), &m)
+	if err != nil {
+		return false, errors.WithMessage(err, "failed to parse notify config")
+	}
+	pushDeerKey, pushDeerUrl := m["pushDeerKey"].(string), m["pushDeerUrl"].(string)
+
+	surl := pushDeerUrl
+	if surl == "" {
+		surl = "https://api2.pushdeer.com/message/push"
+	}
+
+	data := url.Values{}
+	data.Set("pushkey", pushDeerKey)
+	data.Set("text", title)
+	data.Set("desp", content)
+	data.Set("type", "markdown")
+
+	req, err := http.NewRequest("POST", surl, bytes.NewBufferString(data.Encode()))
+	if err != nil {
+		return false, err
+	}
+	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+
+	client := &http.Client{}
+	resp, err := client.Do(req)
+	if err != nil {
+		log.Error("failed to send notification")
+		return false, err
+	}
+	defer resp.Body.Close()
+
+	if resp.StatusCode == http.StatusOK {
+		return true, nil
+	}
+	return false, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
+}
+
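+// TelegramBot below optionally routes through an HTTP proxy assembled from
+// host/port plus an optional "user:pass" auth part. Hypothetical values, for
+// illustration only:
+//
+//	telegramBotProxyHost = "127.0.0.1"
+//	telegramBotProxyPort = "7890"
+//	telegramBotProxyAuth = "user:pass" // => http://user:pass@127.0.0.1:7890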
+// Note: method name must be exported or the reflection lookup in Notify will fail
+func (e SendNotifyPlatform) TelegramBot(body string, title string, content string) (bool, error) {
+	m := make(map[string]interface{})
+	err := json.Unmarshal([]byte(body), &m)
+	if err != nil {
+		return false, errors.WithMessage(err, "failed to parse notify config")
+	}
+	telegramBotToken, telegramBotUserId, telegramBotProxyHost, telegramBotProxyPort, telegramBotProxyAuth, telegramBotApiHost := m["telegramBotToken"].(string), m["telegramBotUserId"].(string), m["telegramBotProxyHost"].(string), m["telegramBotProxyPort"].(string), m["telegramBotProxyAuth"].(string), m["telegramBotApiHost"].(string)
+
+	if telegramBotApiHost == "" {
+		telegramBotApiHost = "https://api.telegram.org"
+	}
+
+	surl := fmt.Sprintf("%s/bot%s/sendMessage", telegramBotApiHost, telegramBotToken)
+
+	var client *http.Client
+	if telegramBotProxyHost != "" && telegramBotProxyPort != "" {
+		proxyURL := fmt.Sprintf("http://%s:%s", telegramBotProxyHost, telegramBotProxyPort)
+		if telegramBotProxyAuth != "" {
+			proxyURL = fmt.Sprintf("http://%s@%s:%s", telegramBotProxyAuth, telegramBotProxyHost, telegramBotProxyPort)
+		}
+
+		proxy := func(_ *http.Request) (*url.URL, error) {
+			return url.Parse(proxyURL)
+		}
+
+		client = &http.Client{
+			Transport: &http.Transport{
+				Proxy: proxy,
+			},
+		}
+	} else {
+		client = http.DefaultClient
+	}
+
+	data := url.Values{}
+	data.Set("chat_id", telegramBotUserId)
+	data.Set("text", fmt.Sprintf("%s\n\n%s", title, content))
+	data.Set("disable_web_page_preview", "true")
+
+	req, err := http.NewRequest("POST", surl, strings.NewReader(data.Encode()))
+	if err != nil {
+		return false, err
+	}
+	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+
+	resp, err := client.Do(req)
+	if err != nil {
+		log.Error("failed to send notification")
+		return false, err
+	}
+	defer resp.Body.Close()
+
+	if resp.StatusCode == http.StatusOK {
+		return true, nil
+	}
+	return false, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
+}
+
+// Note: method name must be exported or the reflection lookup in Notify will fail
+func (e SendNotifyPlatform) WeWorkBot(body string, title string, content string) (bool, error) {
+	m := make(map[string]interface{})
+	err := json.Unmarshal([]byte(body), &m)
+	if err != nil {
+		return false, errors.WithMessage(err, "failed to parse notify config")
+	}
+	weWorkBotKey, weWorkOrigin := m["weWorkBotKey"].(string), m["weWorkOrigin"].(string)
+
+	if weWorkOrigin == "" {
+		weWorkOrigin = "https://qyapi.weixin.qq.com"
+	}
+
+	surl := fmt.Sprintf("%s/cgi-bin/webhook/send?key=%s", weWorkOrigin, weWorkBotKey)
+
+	bodyData := map[string]interface{}{
+		"msgtype": "text",
+		"text": map[string]string{
+			"content": fmt.Sprintf("%s\n\n%s", title, content),
+		},
+	}
+	data, err := json.Marshal(bodyData)
+	if err != nil {
+		return false, err
+	}
+
+	client := http.DefaultClient
+
+	req, err := http.NewRequest("POST", surl, strings.NewReader(string(data)))
+	if err != nil {
+		return false, err
+	}
+	req.Header.Set("Content-Type", "application/json")
+
+	resp, err := client.Do(req)
+	if err != nil {
+		log.Error("failed to send notification")
+		return false, err
+	}
+	defer resp.Body.Close()
+
+	if resp.StatusCode == http.StatusOK {
+		return true, nil
+	}
+	return false, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
+}
+
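+// Webhook below substitutes $title and $content into both the URL and the
+// body, parses the body as "key: value" lines, then encodes it according to
+// WebhookContentType. A hypothetical configuration, for illustration only:
+//
+//	WebhookUrl:     https://example.com/notify?t=$title
+//	WebhookBody:    msg: $content
+//	WebhookHeaders: X-Token: secret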
+// Note: method name must be exported or the reflection lookup in Notify will fail
+func (e SendNotifyPlatform) Webhook(body string, title string, content string) (bool, error) {
+	var webhook model.Webhook
+	err := json.Unmarshal([]byte(body), &webhook)
+	if err != nil {
+		log.Errorln("failed to parse notify config")
+		return false, errors.New("failed to parse notify config")
+	}
+	webhookBodyString := string(webhook.WebhookBody)
+
+	if !strings.Contains(webhook.WebhookUrl, "$title") && !strings.Contains(webhookBodyString, "$title") {
+		return false, errors.New("either the URL or the body must contain $title")
+	}
+
+	headers := make(map[string]string)
+	if len(webhook.WebhookHeaders) > 2 {
+		// split into lines
+		headerLines := strings.Split(webhook.WebhookHeaders, "\n")
+		for _, line := range headerLines {
+			// skip empty lines
+			if line == "" {
+				continue
+			}
+			// split "key: value" pairs on the first colon
+			parts := strings.SplitN(line, ":", 2)
+			if len(parts) != 2 {
+				return false, fmt.Errorf("malformed header: %s", line)
+			}
+			// trim whitespace around key and value
+			key := strings.TrimSpace(parts[0])
+			value := strings.TrimSpace(parts[1])
+			headers[key] = value
+		}
+	}
+	targetBody := strings.ReplaceAll(strings.ReplaceAll(webhookBodyString, "$title", title), "$content", content)
+	rbodys := make(map[string]string)
+	if len(targetBody) > 2 {
+		// the body uses the same "key: value" line format as the headers
+		bodyLines := strings.Split(targetBody, "\n")
+		for _, line := range bodyLines {
+			if line == "" {
+				continue
+			}
+			parts := strings.SplitN(line, ":", 2)
+			if len(parts) != 2 {
+				return false, fmt.Errorf("malformed body line: %s", line)
+			}
+			key := strings.TrimSpace(parts[0])
+			value := strings.TrimSpace(parts[1])
+			rbodys[key] = value
+		}
+	}
+
+	var fbody *bytes.Buffer
+	switch webhook.WebhookContentType {
+	case "application/json":
+		fbody, err = formatJSON(rbodys)
+	case "multipart/form-data":
+		fbody, err = formatMultipart(rbodys)
+	case "application/x-www-form-urlencoded", "text/plain":
+		fbody, err = formatURLForm(rbodys)
+	default:
+		return false, fmt.Errorf("unsupported content type: %s", webhook.WebhookContentType)
+	}
+	if err != nil {
+		log.Errorln("failed to encode webhook body")
+		return false, errors.New("failed to encode webhook body")
+	}
+	formatURL := strings.ReplaceAll(strings.ReplaceAll(webhook.WebhookUrl, "$title", url.QueryEscape(title)), "$content", url.QueryEscape(content))
+	client := &http.Client{}
+	req, err := http.NewRequest(webhook.WebhookMethod, formatURL, fbody)
+	if err != nil {
+		log.Errorln("failed to create webhook request")
+		return false, err
+	}
+	for key, value := range headers {
+		req.Header.Set(key, value)
+	}
+
+	resp, err := client.Do(req)
+	if err != nil {
+		log.Error("failed to send notification")
+		return false, err
+	}
+	defer resp.Body.Close()
+
+	if resp.StatusCode == http.StatusOK {
+		return true, nil
+	}
+	return false, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
+}
+
+func Notify(title string, content string) {
+	platform, err := GetSettingItemByKey(conf.NotifyPlatform)
+	if err != nil {
+		log.Error("failed to load notify platform setting")
+		return
+	}
+	enable, err := GetSettingItemByKey(conf.NotifyEnabled)
+	if err != nil {
+		log.Error("failed to load notify enabled setting")
+		return
+	}
+	notifyBody, err := GetSettingItemByKey(conf.NotifyValue)
+	if err != nil {
+		log.Error("failed to load notify value setting")
+		return
+	}
+	if enable.Value != "true" && enable.Value != "1" {
+		log.Debug("notification push is not enabled")
+		return
+	}
+
+	if !conf.Conf.Notify {
+		log.Debug("notifications are disabled in the config file")
+		return
+	}
+
+	caser := cases.Title(language.English)
+	methodName := caser.String(platform.Value)
+
+	// look up the sender by its (exported) method name via reflection
+	v := reflect.ValueOf(SendNotifyPlatform{})
+	method := v.MethodByName(methodName)
+	if !method.IsValid() {
+		log.Debugf("method %s not found", methodName)
+		return
+	}
+	args := []reflect.Value{reflect.ValueOf(notifyBody.Value), reflect.ValueOf(title), reflect.ValueOf(content)}
+	method.Call(args)
+}
+
+// formatJSON encodes the body as JSON
+func formatJSON(bodys map[string]string) (*bytes.Buffer, error) {
+	jsonData, err := json.Marshal(bodys)
+	if err != nil {
+		return nil, err
+	}
+	return bytes.NewBuffer(jsonData), nil
+}
+
+// formatMultipart encodes the body as multipart/form-data
+func formatMultipart(bodys map[string]string) (*bytes.Buffer, error) {
+	var b bytes.Buffer
+	writer := multipart.NewWriter(&b)
+	for key, value := range bodys {
+		err := writer.WriteField(key, value)
+		if err != nil {
+			return nil, err
+		}
+	}
+	err := writer.Close()
+	if err != nil {
+		return nil, err
+	}
+
+	return &b, nil
+}
+
+// formatURLForm encodes the body as application/x-www-form-urlencoded
+func formatURLForm(bodys
map[string]string) (*bytes.Buffer, error) { + values := url.Values{} + for key, value := range bodys { + values.Add(key, value) + } + formData := values.Encode() + return bytes.NewBufferString(formData), nil +} diff --git a/internal/op/path.go b/internal/op/path.go new file mode 100644 index 0000000000000000000000000000000000000000..27f7e1832282343b8409ecdfe51addcf36fba9d5 --- /dev/null +++ b/internal/op/path.go @@ -0,0 +1,29 @@ +package op + +import ( + "github.com/alist-org/alist/v3/internal/errs" + "strings" + + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/pkg/utils" + log "github.com/sirupsen/logrus" +) + +// GetStorageAndActualPath Get the corresponding storage and actual path +// for path: remove the mount path prefix and join the actual root folder if exists +func GetStorageAndActualPath(rawPath string) (storage driver.Driver, actualPath string, err error) { + rawPath = utils.FixAndCleanPath(rawPath) + storage = GetBalancedStorage(rawPath) + if storage == nil { + if rawPath == "/" { + err = errs.NewErr(errs.StorageNotFound, "please add a storage first") + return + } + err = errs.NewErr(errs.StorageNotFound, "rawPath: %s", rawPath) + return + } + log.Debugln("use storage: ", storage.GetStorage().MountPath) + mountPath := utils.GetActualMountPath(storage.GetStorage().MountPath) + actualPath = utils.FixAndCleanPath(strings.TrimPrefix(rawPath, mountPath)) + return +} diff --git a/internal/op/setting.go b/internal/op/setting.go new file mode 100644 index 0000000000000000000000000000000000000000..83d19c12fbe5da787abbc1d3be862f091f748fcb --- /dev/null +++ b/internal/op/setting.go @@ -0,0 +1,198 @@ +package op + +import ( + "sort" + "strconv" + "strings" + "time" + + "github.com/Xhofe/go-cache" + "github.com/alist-org/alist/v3/internal/db" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/pkg/singleflight" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/pkg/errors" +) + +var settingCache = cache.NewMemCache(cache.WithShards[*model.SettingItem](4)) +var settingG singleflight.Group[*model.SettingItem] +var settingCacheF = func(item *model.SettingItem) { + settingCache.Set(item.Key, item, cache.WithEx[*model.SettingItem](time.Hour)) +} + +var settingGroupCache = cache.NewMemCache(cache.WithShards[[]model.SettingItem](4)) +var settingGroupG singleflight.Group[[]model.SettingItem] +var settingGroupCacheF = func(key string, item []model.SettingItem) { + settingGroupCache.Set(key, item, cache.WithEx[[]model.SettingItem](time.Hour)) +} + +func settingCacheUpdate() { + settingCache.Clear() + settingGroupCache.Clear() +} + +func GetPublicSettingsMap() map[string]string { + items, _ := GetPublicSettingItems() + pSettings := make(map[string]string) + for _, item := range items { + pSettings[item.Key] = item.Value + } + return pSettings +} + +func GetSettingsMap() map[string]string { + items, _ := GetSettingItems() + settings := make(map[string]string) + for _, item := range items { + settings[item.Key] = item.Value + } + return settings +} + +func GetSettingItems() ([]model.SettingItem, error) { + if items, ok := settingGroupCache.Get("ALL_SETTING_ITEMS"); ok { + return items, nil + } + items, err, _ := settingGroupG.Do("ALL_SETTING_ITEMS", func() ([]model.SettingItem, error) { + _items, err := db.GetSettingItems() + if err != nil { + return nil, err + } + settingGroupCacheF("ALL_SETTING_ITEMS", _items) + return _items, nil + }) + return items, err +} + +func GetPublicSettingItems() ([]model.SettingItem, error) { + if items, 
ok := settingGroupCache.Get("ALL_PUBLIC_SETTING_ITEMS"); ok { + return items, nil + } + items, err, _ := settingGroupG.Do("ALL_PUBLIC_SETTING_ITEMS", func() ([]model.SettingItem, error) { + _items, err := db.GetPublicSettingItems() + if err != nil { + return nil, err + } + settingGroupCacheF("ALL_PUBLIC_SETTING_ITEMS", _items) + return _items, nil + }) + return items, err +} + +func GetSettingItemByKey(key string) (*model.SettingItem, error) { + if item, ok := settingCache.Get(key); ok { + return item, nil + } + + item, err, _ := settingG.Do(key, func() (*model.SettingItem, error) { + _item, err := db.GetSettingItemByKey(key) + if err != nil { + return nil, err + } + settingCacheF(_item) + return _item, nil + }) + return item, err +} + +func GetSettingItemInKeys(keys []string) ([]model.SettingItem, error) { + var items []model.SettingItem + for _, key := range keys { + item, err := GetSettingItemByKey(key) + if err != nil { + return nil, err + } + items = append(items, *item) + } + return items, nil +} + +func GetSettingItemsByGroup(group int) ([]model.SettingItem, error) { + key := strconv.Itoa(group) + if items, ok := settingGroupCache.Get(key); ok { + return items, nil + } + items, err, _ := settingGroupG.Do(key, func() ([]model.SettingItem, error) { + _items, err := db.GetSettingItemsByGroup(group) + if err != nil { + return nil, err + } + settingGroupCacheF(key, _items) + return _items, nil + }) + return items, err +} + +func GetSettingItemsInGroups(groups []int) ([]model.SettingItem, error) { + sort.Ints(groups) + key := strings.Join(utils.MustSliceConvert(groups, func(i int) string { + return strconv.Itoa(i) + }), ",") + + if items, ok := settingGroupCache.Get(key); ok { + return items, nil + } + items, err, _ := settingGroupG.Do(key, func() ([]model.SettingItem, error) { + _items, err := db.GetSettingItemsInGroups(groups) + if err != nil { + return nil, err + } + settingGroupCacheF(key, _items) + return _items, nil + }) + return items, err +} + +func SaveSettingItems(items []model.SettingItem) error { + noHookItems := make([]model.SettingItem, 0) + errs := make([]error, 0) + for i := range items { + if ok, err := HandleSettingItemHook(&items[i]); ok { + if err != nil { + errs = append(errs, err) + } else { + err = db.SaveSettingItem(&items[i]) + if err != nil { + errs = append(errs, err) + } + } + } else { + noHookItems = append(noHookItems, items[i]) + } + } + if len(noHookItems) > 0 { + err := db.SaveSettingItems(noHookItems) + if err != nil { + errs = append(errs, err) + } + } + if len(errs) < len(items)-len(noHookItems)+1 { + settingCacheUpdate() + } + return utils.MergeErrors(errs...) 
+}
+
+func SaveSettingItem(item *model.SettingItem) (err error) {
+	// hook
+	if _, err := HandleSettingItemHook(item); err != nil {
+		return err
+	}
+	// update
+	if err = db.SaveSettingItem(item); err != nil {
+		return err
+	}
+	settingCacheUpdate()
+	return nil
+}
+
+func DeleteSettingItemByKey(key string) error {
+	old, err := GetSettingItemByKey(key)
+	if err != nil {
+		return errors.WithMessage(err, "failed to get settingItem")
+	}
+	if !old.IsDeprecated() {
+		return errors.Errorf("setting [%s] is not deprecated", key)
+	}
+	settingCacheUpdate()
+	return db.DeleteSettingItemByKey(key)
+}
diff --git a/internal/op/storage.go b/internal/op/storage.go
new file mode 100644
index 0000000000000000000000000000000000000000..22ace15f74552c19eacaa1270506d21a9bf4c199
--- /dev/null
+++ b/internal/op/storage.go
@@ -0,0 +1,472 @@
+package op
+
+import (
+	"context"
+	"encoding/json"
+	"sort"
+	"strings"
+	"time"
+
+	"github.com/alist-org/alist/v3/internal/conf"
+	"github.com/alist-org/alist/v3/internal/db"
+	"github.com/alist-org/alist/v3/internal/driver"
+	"github.com/alist-org/alist/v3/internal/model"
+	"github.com/alist-org/alist/v3/pkg/generic_sync"
+	"github.com/alist-org/alist/v3/pkg/utils"
+	mapset "github.com/deckarep/golang-set/v2"
+	"github.com/pkg/errors"
+	log "github.com/sirupsen/logrus"
+)
+
+// Although the driver type is stored,
+// there is a storage in each driver,
+// so it should actually be a storage, just wrapped by the driver
+var storagesMap generic_sync.MapOf[string, driver.Driver]
+
+func GetAllStorages() []driver.Driver {
+	return storagesMap.Values()
+}
+
+func HasStorage(mountPath string) bool {
+	return storagesMap.Has(utils.FixAndCleanPath(mountPath))
+}
+
+func GetStorageByMountPath(mountPath string) (driver.Driver, error) {
+	mountPath = utils.FixAndCleanPath(mountPath)
+	storageDriver, ok := storagesMap.Load(mountPath)
+	if !ok {
+		return nil, errors.Errorf("no storage is mounted at: %s", mountPath)
+	}
+	return storageDriver, nil
+}
+
+// CreateStorage Save the storage to database so storage can get an id
+// then instantiate corresponding driver and save it in memory
+func CreateStorage(ctx context.Context, storage model.Storage) (uint, error) {
+	storage.Modified = time.Now()
+	storage.MountPath = utils.FixAndCleanPath(storage.MountPath)
+	var err error
+	// check driver first
+	driverName := storage.Driver
+	driverNew, err := GetDriver(driverName)
+	if err != nil {
+		return 0, errors.WithMessage(err, "failed get driver new")
+	}
+	storageDriver := driverNew()
+	// insert storage to database
+	err = db.CreateStorage(&storage)
+	if err != nil {
+		return storage.ID, errors.WithMessage(err, "failed create storage in database")
+	}
+	// already has an id
+	err = initStorage(ctx, storage, storageDriver)
+	go callStorageHooks("add", storageDriver)
+	if err != nil {
+		return storage.ID, errors.Wrap(err, "failed init storage but storage is already created")
+	}
+	log.Debugf("storage %+v is created", storageDriver)
+	return storage.ID, nil
+}
+
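+// CopyStorageById below clones a storage by round-tripping it through JSON
+// and dropping the "id" key, so the database assigns a fresh primary key on
+// insert. A minimal sketch of the idea (illustrative only):
+//
+//	var data map[string]interface{}
+//	_ = json.Unmarshal(jsonOfOldStorage, &data)
+//	delete(data, "id") // the new row gets a new ID
+//	clone, _ := json.Marshal(data)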
+// CopyStorageById clones an existing storage, identified by ID
+func CopyStorageById(ctx context.Context, id uint) (uint, error) {
+	storage, err := db.GetStorageById(id)
+	if err != nil {
+		return 0, errors.WithMessage(err, "failed to get storage to copy")
+	}
+	// round-trip the storage through JSON so the ID can be dropped
+	jsonData, err := json.Marshal(storage)
+	if err != nil {
+		return 0, errors.WithMessage(err, "failed to marshal storage")
+	}
+	var data map[string]interface{}
+	err = json.Unmarshal(jsonData, &data)
+	if err != nil {
+		return 0, errors.WithMessage(err, "failed to unmarshal storage")
+	}
+	// remove the ID field so the database assigns a new one
+	delete(data, "id")
+	// re-encode the modified map and decode it into a fresh model.Storage
+	result, err := json.Marshal(data)
+	if err != nil {
+		return 0, errors.WithMessage(err, "failed to marshal storage")
+	}
+	var newStorage model.Storage
+	err = json.Unmarshal(result, &newStorage)
+	if err != nil {
+		return 0, errors.WithMessage(err, "failed to unmarshal new storage")
+	}
+
+	// check driver first
+	newStorage.MountPath = storage.MountPath + "_copyed2"
+	newStorage.Modified = time.Now()
+	newStorage.MountPath = utils.FixAndCleanPath(newStorage.MountPath)
+	driverName := newStorage.Driver
+	driverNew, err := GetDriver(driverName)
+	if err != nil {
+		return 0, errors.WithMessage(err, "failed get driver new")
+	}
+	storageDriver := driverNew()
+	// insert storage to database
+	err = db.CreateStorage(&newStorage)
+	if err != nil {
+		return newStorage.ID, errors.WithMessage(err, "failed create storage in database")
+	}
+	// already has an id
+	err = initStorage(ctx, newStorage, storageDriver)
+	go callStorageHooks("add", storageDriver)
+	if err != nil {
+		return newStorage.ID, errors.Wrap(err, "failed init storage but storage is already created")
+	}
+	log.Debugf("storage %+v is created", storageDriver)
+	return newStorage.ID, nil
+}
+
+// LoadStorage load exist storage in db to memory
+func LoadStorage(ctx context.Context, storage model.Storage) error {
+	storage.MountPath = utils.FixAndCleanPath(storage.MountPath)
+	// check driver first
+	driverName := storage.Driver
+	driverNew, err := GetDriver(driverName)
+	if err != nil {
+		return errors.WithMessage(err, "failed get driver new")
+	}
+	storageDriver := driverNew()
+
+	err = initStorage(ctx, storage, storageDriver)
+	go callStorageHooks("add", storageDriver)
+	log.Debugf("storage %+v is created", storageDriver)
+	return err
+}
+
+// initStorage initialize the driver and store to storagesMap
+func initStorage(ctx context.Context, storage model.Storage, storageDriver driver.Driver) (err error) {
+	storageDriver.SetStorage(storage)
+	driverStorage := storageDriver.GetStorage()
+
+	// Unmarshal Addition
+	err = utils.Json.UnmarshalFromString(driverStorage.Addition, storageDriver.GetAddition())
+	if err == nil {
+		err = storageDriver.Init(ctx)
+	}
+	storagesMap.Store(driverStorage.MountPath, storageDriver)
+	if err != nil {
+		driverStorage.SetStatus(err.Error())
+		err = errors.Wrap(err, "failed init storage")
+	} else {
+		driverStorage.SetStatus(WORK)
+	}
+	MustSaveDriverStorage(storageDriver)
+	return err
+}
+
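+// Note: initStorage above stores the driver in storagesMap even when Init
+// fails, so a broken storage still appears in the list with its error as the
+// status instead of vanishing; EnableStorage/DisableStorage toggle that state.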
+func EnableStorage(ctx context.Context, id uint) error {
+	storage, err := db.GetStorageById(id)
+	if err != nil {
+		return errors.WithMessage(err, "failed get storage")
+	}
+	if !storage.Disabled {
+		return errors.Errorf("this storage is already enabled")
+	}
+	storage.Disabled = false
+	err = db.UpdateStorage(storage)
+	if err != nil {
+		return errors.WithMessage(err, "failed update storage in db")
+	}
+	err = LoadStorage(ctx, *storage)
+	if err != nil {
+		return errors.WithMessage(err, "failed load storage")
+	}
+	return nil
+}
+
+func DisableStorage(ctx context.Context, id uint) error {
+	storage, err := db.GetStorageById(id)
+	if err != nil {
+		return errors.WithMessage(err, "failed get storage")
+	}
+	if storage.Disabled {
+		return errors.Errorf("this storage is already disabled")
+	}
+	storageDriver, err := GetStorageByMountPath(storage.MountPath)
+	if err != nil {
+		return errors.WithMessage(err, "failed get storage driver")
+	}
+	// drop the storage in the driver
+	if err := storageDriver.Drop(ctx); err != nil {
+		return errors.Wrap(err, "failed drop storage")
+	}
+	// delete the storage in the memory
+	storage.Disabled = true
+	storage.SetStatus(DISABLED)
+	err = db.UpdateStorage(storage)
+	if err != nil {
+		return errors.WithMessage(err, "failed update storage in db")
+	}
+	storagesMap.Delete(storage.MountPath)
+	go callStorageHooks("del", storageDriver)
+	return nil
+}
+
+// UpdateStorage update storage
+// get old storage first
+// drop the storage then reinitialize
+func UpdateStorage(ctx context.Context, storage model.Storage) error {
+	oldStorage, err := db.GetStorageById(storage.ID)
+	if err != nil {
+		return errors.WithMessage(err, "failed get old storage")
+	}
+	if oldStorage.Driver != storage.Driver {
+		return errors.Errorf("driver cannot be changed")
+	}
+
+	if storage.SyncGroup {
+		storage.Modified = time.Now()
+		storage.SyncGroup = false
+		storage.MountPath = utils.FixAndCleanPath(storage.MountPath)
+
+		// Diff the old and new storages. Fields such as order, cache_expiration,
+		// remark and group would need plain replacement and are not handled yet;
+		// for now only the serialized Addition is compared, and the changed keys
+		// are propagated to every storage in the same group.
+		var changeMap = make(map[string]interface{}) // records the changed keys
+		var storageMap map[string]interface{}
+		err := json.Unmarshal([]byte(storage.Addition), &storageMap)
+		if err != nil {
+			return errors.Errorf("failed to unmarshal new storage addition")
+		}
+
+		var oldStorageMap map[string]interface{}
+		err = json.Unmarshal([]byte(oldStorage.Addition), &oldStorageMap)
+		if err != nil {
+			return errors.Errorf("failed to unmarshal old storage addition")
+		}
+
+		for key, value := range storageMap {
+			oldValue := oldStorageMap[key]
+			if oldValue != value {
+				changeMap[key] = value
+			}
+		}
+
+		if len(changeMap) == 0 {
+			return errors.Errorf("addition unchanged; to edit other fields, turn off the sync-group option")
+		}
+
+		if updateErr := db.UpdateGroupStorages(storage.Group, changeMap); updateErr != nil {
+			return errors.WithMessage(updateErr, "failed to update storages in the same group")
+		}
+		// additions of the whole group are updated at this point
+
+		if storage.Disabled {
+			return nil
+		}
+		if oldStorage.MountPath != storage.MountPath {
+			// mount path renamed, need to drop the storage
+			storagesMap.Delete(oldStorage.MountPath)
+		}
+
+		storages, err := db.GetGroupStorages(storage.Group)
+		go func(storages []model.Storage) {
+			for _, storage := range storages {
+				storageDriver, err := GetStorageByMountPath(storage.MountPath)
+				if err != nil {
+					log.Errorf("failed get storage driver: %+v", err)
+					continue
+				}
+				// drop the storage in the driver
+				if err := storageDriver.Drop(context.Background()); err != nil {
+					log.Errorf("failed drop storage: %+v", err)
+					continue
+				}
+				if err := LoadStorage(context.Background(), storage); err != nil {
+					log.Errorf("failed load storage: %+v", err)
+					continue
+				}
+				log.Infof("success load storage: [%s], driver: [%s]",
+					storage.MountPath, storage.Driver)
+			}
+			conf.StoragesLoaded = true
+		}(storages)
+
+		return err
+	} else {
+		storage.Modified = time.Now()
+		storage.MountPath = utils.FixAndCleanPath(storage.MountPath)
+		storage.SyncGroup = false
+		err =
db.UpdateStorage(&storage) + if err != nil { + return errors.WithMessage(err, "failed update storage in database") + } + if storage.Disabled { + return nil + } + storageDriver, err := GetStorageByMountPath(oldStorage.MountPath) + if oldStorage.MountPath != storage.MountPath { + // mount path renamed, need to drop the storage + storagesMap.Delete(oldStorage.MountPath) + } + if err != nil { + return errors.WithMessage(err, "failed get storage driver") + } + err = storageDriver.Drop(ctx) + if err != nil { + return errors.Wrapf(err, "failed drop storage") + } + + err = initStorage(ctx, storage, storageDriver) + go callStorageHooks("update", storageDriver) + log.Debugf("storage %+v is update", storageDriver) + return err + } + +} + +func DeleteStorageById(ctx context.Context, id uint) error { + storage, err := db.GetStorageById(id) + if err != nil { + return errors.WithMessage(err, "failed get storage") + } + if !storage.Disabled { + storageDriver, err := GetStorageByMountPath(storage.MountPath) + if err != nil { + return errors.WithMessage(err, "failed get storage driver") + } + // drop the storage in the driver + if err := storageDriver.Drop(ctx); err != nil { + return errors.Wrapf(err, "failed drop storage") + } + // delete the storage in the memory + storagesMap.Delete(storage.MountPath) + go callStorageHooks("del", storageDriver) + } + // delete the storage in the database + if err := db.DeleteStorageById(id); err != nil { + return errors.WithMessage(err, "failed delete storage in database") + } + return nil +} + +// MustSaveDriverStorage call from specific driver +func MustSaveDriverStorage(driver driver.Driver) { + err := saveDriverStorage(driver) + if err != nil { + log.Errorf("failed save driver storage: %s", err) + } +} + +func saveDriverStorage(driver driver.Driver) error { + storage := driver.GetStorage() + addition := driver.GetAddition() + str, err := utils.Json.MarshalToString(addition) + if err != nil { + return errors.Wrap(err, "error while marshal addition") + } + storage.Addition = str + err = db.UpdateStorage(storage) + if err != nil { + return errors.WithMessage(err, "failed update storage in database") + } + return nil +} + +// getStoragesByPath get storage by longest match path, contains balance storage. 
+// for example, there is /a/b,/a/c,/a/d/e,/a/d/e.balance +// getStoragesByPath(/a/d/e/f) => /a/d/e,/a/d/e.balance +func getStoragesByPath(path string) []driver.Driver { + storages := make([]driver.Driver, 0) + curSlashCount := 0 + storagesMap.Range(func(mountPath string, value driver.Driver) bool { + mountPath = utils.GetActualMountPath(mountPath) + // is this path + if utils.IsSubPath(mountPath, path) { + slashCount := strings.Count(utils.PathAddSeparatorSuffix(mountPath), "/") + // not the longest match + if slashCount > curSlashCount { + storages = storages[:0] + curSlashCount = slashCount + } + if slashCount == curSlashCount { + storages = append(storages, value) + } + } + return true + }) + // make sure the order is the same for same input + sort.Slice(storages, func(i, j int) bool { + return storages[i].GetStorage().MountPath < storages[j].GetStorage().MountPath + }) + return storages +} + +// GetStorageVirtualFilesByPath Obtain the virtual file generated by the storage according to the path +// for example, there are: /a/b,/a/c,/a/d/e,/a/b.balance1,/av +// GetStorageVirtualFilesByPath(/a) => b,c,d +func GetStorageVirtualFilesByPath(prefix string) []model.Obj { + files := make([]model.Obj, 0) + storages := storagesMap.Values() + sort.Slice(storages, func(i, j int) bool { + if storages[i].GetStorage().Order == storages[j].GetStorage().Order { + return storages[i].GetStorage().MountPath < storages[j].GetStorage().MountPath + } + return storages[i].GetStorage().Order < storages[j].GetStorage().Order + }) + + prefix = utils.FixAndCleanPath(prefix) + set := mapset.NewSet[string]() + for _, v := range storages { + mountPath := utils.GetActualMountPath(v.GetStorage().MountPath) + // Exclude prefix itself and non prefix + if len(prefix) >= len(mountPath) || !utils.IsSubPath(prefix, mountPath) { + continue + } + name := strings.SplitN(strings.TrimPrefix(mountPath[len(prefix):], "/"), "/", 2)[0] + if set.Add(name) { + files = append(files, &model.Object{ + Name: name, + Size: 0, + Modified: v.GetStorage().Modified, + IsFolder: true, + }) + } + } + return files +} + +var balanceMap generic_sync.MapOf[string, int] + +// GetBalancedStorage get storage by path +func GetBalancedStorage(path string) driver.Driver { + path = utils.FixAndCleanPath(path) + storages := getStoragesByPath(path) + storageNum := len(storages) + switch storageNum { + case 0: + return nil + case 1: + return storages[0] + default: + virtualPath := utils.GetActualMountPath(storages[0].GetStorage().MountPath) + i, _ := balanceMap.LoadOrStore(virtualPath, 0) + i = (i + 1) % storageNum + balanceMap.Store(virtualPath, i) + return storages[i] + } +} diff --git a/internal/op/storage_test.go b/internal/op/storage_test.go new file mode 100644 index 0000000000000000000000000000000000000000..f3d0b4e64b7988db836fe6292e7a23f55950a5bd --- /dev/null +++ b/internal/op/storage_test.go @@ -0,0 +1,91 @@ +package op_test + +import ( + "context" + "testing" + + "github.com/alist-org/alist/v3/internal/conf" + "github.com/alist-org/alist/v3/internal/db" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/internal/op" + "github.com/alist-org/alist/v3/pkg/utils" + mapset "github.com/deckarep/golang-set/v2" + "gorm.io/driver/sqlite" + "gorm.io/gorm" +) + +func init() { + dB, err := gorm.Open(sqlite.Open("file::memory:?cache=shared"), &gorm.Config{}) + if err != nil { + panic("failed to connect database") + } + conf.Conf = conf.DefaultConfig() + db.Init(dB) +} + +func TestCreateStorage(t *testing.T) { + var storages = 
[]struct { + storage model.Storage + isErr bool + }{ + {storage: model.Storage{Driver: "Local", MountPath: "/local", Addition: `{"root_folder_path":"."}`}, isErr: false}, + {storage: model.Storage{Driver: "Local", MountPath: "/local", Addition: `{"root_folder_path":"."}`}, isErr: true}, + {storage: model.Storage{Driver: "None", MountPath: "/none", Addition: `{"root_folder_path":"."}`}, isErr: true}, + } + for _, storage := range storages { + _, err := op.CreateStorage(context.Background(), storage.storage) + if err != nil { + if !storage.isErr { + t.Errorf("failed to create storage: %+v", err) + } else { + t.Logf("expect failed to create storage: %+v", err) + } + } + } +} + +func TestGetStorageVirtualFilesByPath(t *testing.T) { + setupStorages(t) + virtualFiles := op.GetStorageVirtualFilesByPath("/a") + var names []string + for _, virtualFile := range virtualFiles { + names = append(names, virtualFile.GetName()) + } + var expectedNames = []string{"b", "c", "d"} + if utils.SliceEqual(names, expectedNames) { + t.Logf("passed") + } else { + t.Errorf("expected: %+v, got: %+v", expectedNames, names) + } +} + +func TestGetBalancedStorage(t *testing.T) { + setupStorages(t) + set := mapset.NewSet[string]() + for i := 0; i < 5; i++ { + storage := op.GetBalancedStorage("/a/d/e1") + set.Add(storage.GetStorage().MountPath) + } + expected := mapset.NewSet([]string{"/a/d/e1", "/a/d/e1.balance"}...) + if !expected.Equal(set) { + t.Errorf("expected: %+v, got: %+v", expected, set) + } +} + +func setupStorages(t *testing.T) { + var storages = []model.Storage{ + {Driver: "Local", MountPath: "/a/b", Order: 0, Addition: `{"root_folder_path":"."}`}, + {Driver: "Local", MountPath: "/adc", Order: 0, Addition: `{"root_folder_path":"."}`}, + {Driver: "Local", MountPath: "/a/c", Order: 1, Addition: `{"root_folder_path":"."}`}, + {Driver: "Local", MountPath: "/a/d", Order: 2, Addition: `{"root_folder_path":"."}`}, + {Driver: "Local", MountPath: "/a/d/e1", Order: 3, Addition: `{"root_folder_path":"."}`}, + {Driver: "Local", MountPath: "/a/d/e", Order: 4, Addition: `{"root_folder_path":"."}`}, + {Driver: "Local", MountPath: "/a/d/e1.balance", Order: 4, Addition: `{"root_folder_path":"."}`}, + } + for _, storage := range storages { + _, err := op.CreateStorage(context.Background(), storage) + if err != nil { + t.Fatalf("failed to create storage: %+v", err) + } + } +} diff --git a/internal/op/user.go b/internal/op/user.go new file mode 100644 index 0000000000000000000000000000000000000000..79e73db86ce1bbe26726d1635613817df34e6313 --- /dev/null +++ b/internal/op/user.go @@ -0,0 +1,130 @@ +package op + +import ( + "time" + + "github.com/Xhofe/go-cache" + "github.com/alist-org/alist/v3/internal/db" + "github.com/alist-org/alist/v3/internal/errs" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/pkg/singleflight" + "github.com/alist-org/alist/v3/pkg/utils" +) + +var userCache = cache.NewMemCache(cache.WithShards[*model.User](2)) +var userG singleflight.Group[*model.User] +var guestUser *model.User +var adminUser *model.User + +func GetAdmin() (*model.User, error) { + if adminUser == nil { + user, err := db.GetUserByRole(model.ADMIN) + if err != nil { + return nil, err + } + adminUser = user + } + return adminUser, nil +} + +func GetGuest() (*model.User, error) { + if guestUser == nil { + user, err := db.GetUserByRole(model.GUEST) + if err != nil { + return nil, err + } + guestUser = user + } + return guestUser, nil +} + +func GetUserByRole(role int) (*model.User, error) { + return 
db.GetUserByRole(role) +} + +func GetUserByName(username string) (*model.User, error) { + if username == "" { + return nil, errs.EmptyUsername + } + if user, ok := userCache.Get(username); ok { + return user, nil + } + user, err, _ := userG.Do(username, func() (*model.User, error) { + _user, err := db.GetUserByName(username) + if err != nil { + return nil, err + } + userCache.Set(username, _user, cache.WithEx[*model.User](time.Hour)) + return _user, nil + }) + return user, err +} + +func GetUserById(id uint) (*model.User, error) { + return db.GetUserById(id) +} + +func GetUsers(pageIndex, pageSize int) (users []model.User, count int64, err error) { + return db.GetUsers(pageIndex, pageSize) +} + +func CreateUser(u *model.User) error { + u.BasePath = utils.FixAndCleanPath(u.BasePath) + return db.CreateUser(u) +} + +func DeleteUserById(id uint) error { + old, err := db.GetUserById(id) + if err != nil { + return err + } + if old.IsAdmin() || old.IsGuest() { + return errs.DeleteAdminOrGuest + } + userCache.Del(old.Username) + return db.DeleteUserById(id) +} + +func UpdateUser(u *model.User) error { + old, err := db.GetUserById(u.ID) + if err != nil { + return err + } + if u.IsAdmin() { + adminUser = nil + } + if u.IsGuest() { + guestUser = nil + } + userCache.Del(old.Username) + u.BasePath = utils.FixAndCleanPath(u.BasePath) + return db.UpdateUser(u) +} + +func Cancel2FAByUser(u *model.User) error { + u.OtpSecret = "" + return UpdateUser(u) +} + +func Cancel2FAById(id uint) error { + user, err := db.GetUserById(id) + if err != nil { + return err + } + return Cancel2FAByUser(user) +} + +func DelUserCache(username string) error { + user, err := GetUserByName(username) + if err != nil { + return err + } + if user.IsAdmin() { + adminUser = nil + } + if user.IsGuest() { + guestUser = nil + } + userCache.Del(username) + return nil +} diff --git a/internal/search/bleve/init.go b/internal/search/bleve/init.go new file mode 100644 index 0000000000000000000000000000000000000000..e764c3cf8623542b2d78336f57074b19c3bc133a --- /dev/null +++ b/internal/search/bleve/init.go @@ -0,0 +1,47 @@ +package bleve + +import ( + "github.com/alist-org/alist/v3/internal/conf" + "github.com/alist-org/alist/v3/internal/search/searcher" + "github.com/blevesearch/bleve/v2" + log "github.com/sirupsen/logrus" +) + +var config = searcher.Config{ + Name: "bleve", +} + +func Init(indexPath *string) (bleve.Index, error) { + log.Debugf("bleve path: %s", *indexPath) + fileIndex, err := bleve.Open(*indexPath) + if err == bleve.ErrorIndexPathDoesNotExist { + log.Infof("Creating new index...") + indexMapping := bleve.NewIndexMapping() + searchNodeMapping := bleve.NewDocumentMapping() + searchNodeMapping.AddFieldMappingsAt("is_dir", bleve.NewBooleanFieldMapping()) + // TODO: appoint analyzer + parentFieldMapping := bleve.NewTextFieldMapping() + searchNodeMapping.AddFieldMappingsAt("parent", parentFieldMapping) + // TODO: appoint analyzer + nameFieldMapping := bleve.NewKeywordFieldMapping() + searchNodeMapping.AddFieldMappingsAt("name", nameFieldMapping) + indexMapping.AddDocumentMapping("SearchNode", searchNodeMapping) + fileIndex, err = bleve.New(*indexPath, indexMapping) + if err != nil { + return nil, err + } + } else if err != nil { + return nil, err + } + return fileIndex, nil +} + +func init() { + searcher.RegisterSearcher(config, func() (searcher.Searcher, error) { + b, err := Init(&conf.Conf.BleveDir) + if err != nil { + return nil, err + } + return &Bleve{BIndex: b}, nil + }) +} diff --git a/internal/search/bleve/search.go 
b/internal/search/bleve/search.go new file mode 100644 index 0000000000000000000000000000000000000000..b69b3f298e76628b126cffd31f14cda20e590003 --- /dev/null +++ b/internal/search/bleve/search.go @@ -0,0 +1,105 @@ +package bleve + +import ( + "context" + "os" + + query2 "github.com/blevesearch/bleve/v2/search/query" + + "github.com/alist-org/alist/v3/internal/conf" + "github.com/alist-org/alist/v3/internal/errs" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/internal/search/searcher" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/blevesearch/bleve/v2" + search2 "github.com/blevesearch/bleve/v2/search" + "github.com/google/uuid" + log "github.com/sirupsen/logrus" +) + +type Bleve struct { + BIndex bleve.Index +} + +func (b *Bleve) Config() searcher.Config { + return config +} + +func (b *Bleve) Search(ctx context.Context, req model.SearchReq) ([]model.SearchNode, int64, error) { + var queries []query2.Query + query := bleve.NewMatchQuery(req.Keywords) + query.SetField("name") + queries = append(queries, query) + if req.Scope != 0 { + isDir := req.Scope == 1 + isDirQuery := bleve.NewBoolFieldQuery(isDir) + queries = append(queries, isDirQuery) + } + reqQuery := bleve.NewConjunctionQuery(queries...) + search := bleve.NewSearchRequest(reqQuery) + search.SortBy([]string{"name"}) + search.From = (req.Page - 1) * req.PerPage + search.Size = req.PerPage + search.Fields = []string{"*"} + searchResults, err := b.BIndex.Search(search) + if err != nil { + log.Errorf("search error: %+v", err) + return nil, 0, err + } + res, err := utils.SliceConvert(searchResults.Hits, func(src *search2.DocumentMatch) (model.SearchNode, error) { + return model.SearchNode{ + Parent: src.Fields["parent"].(string), + Name: src.Fields["name"].(string), + IsDir: src.Fields["is_dir"].(bool), + Size: int64(src.Fields["size"].(float64)), + }, nil + }) + return res, int64(searchResults.Total), nil +} + +func (b *Bleve) Index(ctx context.Context, node model.SearchNode) error { + return b.BIndex.Index(uuid.NewString(), node) +} + +func (b *Bleve) BatchIndex(ctx context.Context, nodes []model.SearchNode) error { + batch := b.BIndex.NewBatch() + for _, node := range nodes { + batch.Index(uuid.NewString(), node) + } + return b.BIndex.Batch(batch) +} + +func (b *Bleve) Get(ctx context.Context, parent string) ([]model.SearchNode, error) { + return nil, errs.NotSupport +} + +func (b *Bleve) Del(ctx context.Context, prefix string) error { + return errs.NotSupport +} + +func (b *Bleve) Release(ctx context.Context) error { + if b.BIndex != nil { + return b.BIndex.Close() + } + return nil +} + +func (b *Bleve) Clear(ctx context.Context) error { + err := b.Release(ctx) + if err != nil { + return err + } + log.Infof("Removing old index...") + err = os.RemoveAll(conf.Conf.BleveDir) + if err != nil { + log.Errorf("clear bleve error: %+v", err) + } + bIndex, err := Init(&conf.Conf.BleveDir) + if err != nil { + return err + } + b.BIndex = bIndex + return nil +} + +var _ searcher.Searcher = (*Bleve)(nil) diff --git a/internal/search/build.go b/internal/search/build.go new file mode 100644 index 0000000000000000000000000000000000000000..1d3bfb7cd5d28340fa068e012dada4c057dce134 --- /dev/null +++ b/internal/search/build.go @@ -0,0 +1,239 @@ +package search + +import ( + "context" + "path" + "path/filepath" + "strings" + "sync/atomic" + "time" + + "github.com/alist-org/alist/v3/internal/conf" + "github.com/alist-org/alist/v3/internal/fs" + "github.com/alist-org/alist/v3/internal/model" + 
"github.com/alist-org/alist/v3/internal/op" + "github.com/alist-org/alist/v3/internal/search/searcher" + "github.com/alist-org/alist/v3/internal/setting" + "github.com/alist-org/alist/v3/pkg/mq" + "github.com/alist-org/alist/v3/pkg/utils" + mapset "github.com/deckarep/golang-set/v2" + log "github.com/sirupsen/logrus" +) + +var ( + Running = atomic.Bool{} + Quit chan struct{} +) + +func BuildIndex(ctx context.Context, indexPaths, ignorePaths []string, maxDepth int, count bool) error { + var ( + err error + objCount uint64 = 0 + fi model.Obj + ) + log.Infof("build index for: %+v", indexPaths) + log.Infof("ignore paths: %+v", ignorePaths) + Running.Store(true) + Quit = make(chan struct{}, 1) + indexMQ := mq.NewInMemoryMQ[ObjWithParent]() + go func() { + ticker := time.NewTicker(time.Second) + tickCount := 0 + for { + select { + case <-ticker.C: + tickCount += 1 + if indexMQ.Len() < 1000 && tickCount != 5 { + continue + } else if tickCount >= 5 { + tickCount = 0 + } + log.Infof("index obj count: %d", objCount) + indexMQ.ConsumeAll(func(messages []mq.Message[ObjWithParent]) { + if len(messages) != 0 { + log.Debugf("current index: %s", messages[len(messages)-1].Content.Parent) + } + if err = BatchIndex(ctx, utils.MustSliceConvert(messages, + func(src mq.Message[ObjWithParent]) ObjWithParent { + return src.Content + })); err != nil { + log.Errorf("build index in batch error: %+v", err) + } else { + objCount = objCount + uint64(len(messages)) + } + if count { + WriteProgress(&model.IndexProgress{ + ObjCount: objCount, + IsDone: false, + LastDoneTime: nil, + }) + } + }) + + case <-Quit: + Running.Store(false) + ticker.Stop() + eMsg := "" + now := time.Now() + originErr := err + indexMQ.ConsumeAll(func(messages []mq.Message[ObjWithParent]) { + if err = BatchIndex(ctx, utils.MustSliceConvert(messages, + func(src mq.Message[ObjWithParent]) ObjWithParent { + return src.Content + })); err != nil { + log.Errorf("build index in batch error: %+v", err) + } else { + objCount = objCount + uint64(len(messages)) + } + if originErr != nil { + log.Errorf("build index error: %+v", originErr) + eMsg = originErr.Error() + } else { + log.Infof("success build index, count: %d", objCount) + } + if count { + WriteProgress(&model.IndexProgress{ + ObjCount: objCount, + IsDone: true, + LastDoneTime: &now, + Error: eMsg, + }) + } + }) + return + } + } + }() + defer func() { + if Running.Load() { + Quit <- struct{}{} + } + }() + admin, err := op.GetAdmin() + if err != nil { + return err + } + if count { + WriteProgress(&model.IndexProgress{ + ObjCount: 0, + IsDone: false, + }) + } + for _, indexPath := range indexPaths { + walkFn := func(indexPath string, info model.Obj) error { + if !Running.Load() { + return filepath.SkipDir + } + for _, avoidPath := range ignorePaths { + if strings.HasPrefix(indexPath, avoidPath) { + return filepath.SkipDir + } + } + // ignore root + if indexPath == "/" { + return nil + } + indexMQ.Publish(mq.Message[ObjWithParent]{ + Content: ObjWithParent{ + Obj: info, + Parent: path.Dir(indexPath), + }, + }) + return nil + } + fi, err = fs.Get(ctx, indexPath, &fs.GetArgs{}) + if err != nil { + return err + } + // TODO: run walkFS concurrently + err = fs.WalkFS(context.WithValue(ctx, "user", admin), maxDepth, indexPath, fi, walkFn) + if err != nil { + return err + } + } + return nil +} + +func Del(ctx context.Context, prefix string) error { + return instance.Del(ctx, prefix) +} + +func Clear(ctx context.Context) error { + return instance.Clear(ctx) +} + +func Config(ctx context.Context) searcher.Config 
{ + return instance.Config() +} + +func Update(parent string, objs []model.Obj) { + if instance == nil || !instance.Config().AutoUpdate || !setting.GetBool(conf.AutoUpdateIndex) || Running.Load() { + return + } + if isIgnorePath(parent) { + return + } + ctx := context.Background() + // only update when index have built + progress, err := Progress() + if err != nil { + log.Errorf("update search index error while get progress: %+v", err) + return + } + if !progress.IsDone { + return + } + nodes, err := instance.Get(ctx, parent) + if err != nil { + log.Errorf("update search index error while get nodes: %+v", err) + return + } + now := mapset.NewSet[string]() + for i := range objs { + now.Add(objs[i].GetName()) + } + old := mapset.NewSet[string]() + for i := range nodes { + old.Add(nodes[i].Name) + } + // delete data that no longer exists + toDelete := old.Difference(now) + toAdd := now.Difference(old) + for i := range nodes { + if toDelete.Contains(nodes[i].Name) && !op.HasStorage(path.Join(parent, nodes[i].Name)) { + log.Debugf("delete index: %s", path.Join(parent, nodes[i].Name)) + err = instance.Del(ctx, path.Join(parent, nodes[i].Name)) + if err != nil { + log.Errorf("update search index error while del old node: %+v", err) + return + } + } + } + for i := range objs { + if toAdd.Contains(objs[i].GetName()) { + if !objs[i].IsDir() { + log.Debugf("add index: %s", path.Join(parent, objs[i].GetName())) + err = Index(ctx, parent, objs[i]) + if err != nil { + log.Errorf("update search index error while index new node: %+v", err) + return + } + } else { + // build index if it's a folder + dir := path.Join(parent, objs[i].GetName()) + err = BuildIndex(ctx, + []string{dir}, + conf.SlicesMap[conf.IgnorePaths], + setting.GetInt(conf.MaxIndexDepth, 20)-strings.Count(dir, "/"), false) + if err != nil { + log.Errorf("update search index error while build index: %+v", err) + return + } + } + } + } +} + +func init() { + op.RegisterObjsUpdateHook(Update) +} diff --git a/internal/search/db/init.go b/internal/search/db/init.go new file mode 100644 index 0000000000000000000000000000000000000000..b7d0288f9474a0ee87d3d0baf7c072aec141d46a --- /dev/null +++ b/internal/search/db/init.go @@ -0,0 +1,42 @@ +package db + +import ( + "fmt" + "strings" + + log "github.com/sirupsen/logrus" + + "github.com/alist-org/alist/v3/internal/conf" + "github.com/alist-org/alist/v3/internal/db" + "github.com/alist-org/alist/v3/internal/search/searcher" +) + +var config = searcher.Config{ + Name: "database", + AutoUpdate: true, +} + +func init() { + searcher.RegisterSearcher(config, func() (searcher.Searcher, error) { + db := db.GetDb() + switch conf.Conf.Database.Type { + case "mysql": + tableName := fmt.Sprintf("%ssearch_nodes", conf.Conf.Database.TablePrefix) + tx := db.Exec(fmt.Sprintf("CREATE FULLTEXT INDEX idx_%s_name_fulltext ON %s(name);", tableName, tableName)) + if err := tx.Error; err != nil && !strings.Contains(err.Error(), "Error 1061 (42000)") { // duplicate error + log.Errorf("failed to create full text index: %v", err) + return nil, err + } + case "postgres": + db.Exec("CREATE EXTENSION pg_trgm;") + db.Exec("CREATE EXTENSION btree_gin;") + tableName := fmt.Sprintf("%ssearch_nodes", conf.Conf.Database.TablePrefix) + tx := db.Exec(fmt.Sprintf("CREATE INDEX idx_%s_name ON %s USING GIN (name);", tableName, tableName)) + if err := tx.Error; err != nil && !strings.Contains(err.Error(), "SQLSTATE 42P07") { + log.Errorf("failed to create index using GIN: %v", err) + return nil, err + } + } + return &DB{}, nil + }) +} diff 
--git a/internal/search/db/search.go b/internal/search/db/search.go new file mode 100644 index 0000000000000000000000000000000000000000..70baef063e6ab1d76b7316858cce77ef60a37c68 --- /dev/null +++ b/internal/search/db/search.go @@ -0,0 +1,45 @@ +package db + +import ( + "context" + + "github.com/alist-org/alist/v3/internal/db" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/internal/search/searcher" +) + +type DB struct{} + +func (D DB) Config() searcher.Config { + return config +} + +func (D DB) Search(ctx context.Context, req model.SearchReq) ([]model.SearchNode, int64, error) { + return db.SearchNode(req, true) +} + +func (D DB) Index(ctx context.Context, node model.SearchNode) error { + return db.CreateSearchNode(&node) +} + +func (D DB) BatchIndex(ctx context.Context, nodes []model.SearchNode) error { + return db.BatchCreateSearchNodes(&nodes) +} + +func (D DB) Get(ctx context.Context, parent string) ([]model.SearchNode, error) { + return db.GetSearchNodesByParent(parent) +} + +func (D DB) Del(ctx context.Context, path string) error { + return db.DeleteSearchNodesByParent(path) +} + +func (D DB) Release(ctx context.Context) error { + return nil +} + +func (D DB) Clear(ctx context.Context) error { + return db.ClearSearchNodes() +} + +var _ searcher.Searcher = (*DB)(nil) diff --git a/internal/search/db_non_full_text/init.go b/internal/search/db_non_full_text/init.go new file mode 100644 index 0000000000000000000000000000000000000000..a0fcae493951609b5ee26dc4a846c907347ba710 --- /dev/null +++ b/internal/search/db_non_full_text/init.go @@ -0,0 +1,16 @@ +package db_non_full_text + +import ( + "github.com/alist-org/alist/v3/internal/search/searcher" +) + +var config = searcher.Config{ + Name: "database_non_full_text", + AutoUpdate: true, +} + +func init() { + searcher.RegisterSearcher(config, func() (searcher.Searcher, error) { + return &DB{}, nil + }) +} diff --git a/internal/search/db_non_full_text/search.go b/internal/search/db_non_full_text/search.go new file mode 100644 index 0000000000000000000000000000000000000000..8589114be925088667469e511cec46db297408fa --- /dev/null +++ b/internal/search/db_non_full_text/search.go @@ -0,0 +1,45 @@ +package db_non_full_text + +import ( + "context" + + "github.com/alist-org/alist/v3/internal/db" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/internal/search/searcher" +) + +type DB struct{} + +func (D DB) Config() searcher.Config { + return config +} + +func (D DB) Search(ctx context.Context, req model.SearchReq) ([]model.SearchNode, int64, error) { + return db.SearchNode(req, false) +} + +func (D DB) Index(ctx context.Context, node model.SearchNode) error { + return db.CreateSearchNode(&node) +} + +func (D DB) BatchIndex(ctx context.Context, nodes []model.SearchNode) error { + return db.BatchCreateSearchNodes(&nodes) +} + +func (D DB) Get(ctx context.Context, parent string) ([]model.SearchNode, error) { + return db.GetSearchNodesByParent(parent) +} + +func (D DB) Del(ctx context.Context, path string) error { + return db.DeleteSearchNodesByParent(path) +} + +func (D DB) Release(ctx context.Context) error { + return nil +} + +func (D DB) Clear(ctx context.Context) error { + return db.ClearSearchNodes() +} + +var _ searcher.Searcher = (*DB)(nil) diff --git a/internal/search/import.go b/internal/search/import.go new file mode 100644 index 0000000000000000000000000000000000000000..a34c36f9a3499c5f6d1b366d7703142e279ff6b8 --- /dev/null +++ b/internal/search/import.go @@ -0,0 +1,8 @@ 
+package search + +import ( + _ "github.com/alist-org/alist/v3/internal/search/bleve" + _ "github.com/alist-org/alist/v3/internal/search/db" + _ "github.com/alist-org/alist/v3/internal/search/db_non_full_text" + _ "github.com/alist-org/alist/v3/internal/search/meilisearch" +) diff --git a/internal/search/meilisearch/init.go b/internal/search/meilisearch/init.go new file mode 100644 index 0000000000000000000000000000000000000000..8f5f24733ee5e48c09c90f0586f0fadad3eba3b3 --- /dev/null +++ b/internal/search/meilisearch/init.go @@ -0,0 +1,89 @@ +package meilisearch + +import ( + "errors" + "fmt" + "github.com/alist-org/alist/v3/internal/conf" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/internal/search/searcher" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/meilisearch/meilisearch-go" +) + +var config = searcher.Config{ + Name: "meilisearch", + AutoUpdate: true, +} + +func init() { + searcher.RegisterSearcher(config, func() (searcher.Searcher, error) { + m := Meilisearch{ + Client: meilisearch.NewClient(meilisearch.ClientConfig{ + Host: conf.Conf.Meilisearch.Host, + APIKey: conf.Conf.Meilisearch.APIKey, + }), + IndexUid: conf.Conf.Meilisearch.IndexPrefix + "alist", + FilterableAttributes: []string{"parent", "is_dir", "name"}, + SearchableAttributes: []string{"name"}, + } + + _, err := m.Client.GetIndex(m.IndexUid) + if err != nil { + var mErr *meilisearch.Error + ok := errors.As(err, &mErr) + if ok && mErr.MeilisearchApiError.Code == "index_not_found" { + task, err := m.Client.CreateIndex(&meilisearch.IndexConfig{ + Uid: m.IndexUid, + PrimaryKey: "id", + }) + if err != nil { + return nil, err + } + forTask, err := m.Client.WaitForTask(task.TaskUID) + if err != nil { + return nil, err + } + if forTask.Status != meilisearch.TaskStatusSucceeded { + return nil, fmt.Errorf("index creation failed, task status is %s", forTask.Status) + } + } else { + return nil, err + } + } + attributes, err := m.Client.Index(m.IndexUid).GetFilterableAttributes() + if err != nil { + return nil, err + } + if attributes == nil || !utils.SliceAllContains(*attributes, m.FilterableAttributes...) { + _, err = m.Client.Index(m.IndexUid).UpdateFilterableAttributes(&m.FilterableAttributes) + if err != nil { + return nil, err + } + } + + attributes, err = m.Client.Index(m.IndexUid).GetSearchableAttributes() + if err != nil { + return nil, err + } + if attributes == nil || !utils.SliceAllContains(*attributes, m.SearchableAttributes...) 
{ + _, err = m.Client.Index(m.IndexUid).UpdateSearchableAttributes(&m.SearchableAttributes) + if err != nil { + return nil, err + } + } + + pagination, err := m.Client.Index(m.IndexUid).GetPagination() + if err != nil { + return nil, err + } + if pagination.MaxTotalHits != int64(model.MaxInt) { + _, err := m.Client.Index(m.IndexUid).UpdatePagination(&meilisearch.Pagination{ + MaxTotalHits: int64(model.MaxInt), + }) + if err != nil { + return nil, err + } + } + return &m, nil + }) +} diff --git a/internal/search/meilisearch/search.go b/internal/search/meilisearch/search.go new file mode 100644 index 0000000000000000000000000000000000000000..1516306b75ffe7f492f8254a54e86b27d4887ac1 --- /dev/null +++ b/internal/search/meilisearch/search.go @@ -0,0 +1,227 @@ +package meilisearch + +import ( + "context" + "fmt" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/internal/search/searcher" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/google/uuid" + "github.com/meilisearch/meilisearch-go" + "path" + "strings" + "time" +) + +type searchDocument struct { + ID string `json:"id"` + model.SearchNode +} + +type Meilisearch struct { + Client *meilisearch.Client + IndexUid string + FilterableAttributes []string + SearchableAttributes []string +} + +func (m *Meilisearch) Config() searcher.Config { + return config +} + +func (m *Meilisearch) Search(ctx context.Context, req model.SearchReq) ([]model.SearchNode, int64, error) { + mReq := &meilisearch.SearchRequest{ + AttributesToSearchOn: m.SearchableAttributes, + Page: int64(req.Page), + HitsPerPage: int64(req.PerPage), + } + if req.Scope != 0 { + mReq.Filter = fmt.Sprintf("is_dir = %v", req.Scope == 1) + } + search, err := m.Client.Index(m.IndexUid).Search(req.Keywords, mReq) + if err != nil { + return nil, 0, err + } + nodes, err := utils.SliceConvert(search.Hits, func(src any) (model.SearchNode, error) { + srcMap := src.(map[string]any) + return model.SearchNode{ + Parent: srcMap["parent"].(string), + Name: srcMap["name"].(string), + IsDir: srcMap["is_dir"].(bool), + Size: int64(srcMap["size"].(float64)), + }, nil + }) + if err != nil { + return nil, 0, err + } + return nodes, search.TotalHits, nil +} + +func (m *Meilisearch) Index(ctx context.Context, node model.SearchNode) error { + return m.BatchIndex(ctx, []model.SearchNode{node}) +} + +func (m *Meilisearch) BatchIndex(ctx context.Context, nodes []model.SearchNode) error { + documents, _ := utils.SliceConvert(nodes, func(src model.SearchNode) (*searchDocument, error) { + + return &searchDocument{ + ID: uuid.NewString(), + SearchNode: src, + }, nil + }) + + _, err := m.Client.Index(m.IndexUid).AddDocuments(documents) + if err != nil { + return err + } + + //// Wait for the task to complete and check + //forTask, err := m.Client.WaitForTask(task.TaskUID, meilisearch.WaitParams{ + // Context: ctx, + // Interval: time.Millisecond * 50, + //}) + //if err != nil { + // return err + //} + //if forTask.Status != meilisearch.TaskStatusSucceeded { + // return fmt.Errorf("BatchIndex failed, task status is %s", forTask.Status) + //} + return nil +} + +func (m *Meilisearch) getDocumentsByParent(ctx context.Context, parent string) ([]*searchDocument, error) { + var result meilisearch.DocumentsResult + err := m.Client.Index(m.IndexUid).GetDocuments(&meilisearch.DocumentsQuery{ + Filter: fmt.Sprintf("parent = '%s'", strings.ReplaceAll(parent, "'", "\\'")), + Limit: int64(model.MaxInt), + }, &result) + if err != nil { + return nil, err + } + return 
utils.SliceConvert(result.Results, func(src map[string]any) (*searchDocument, error) { + return &searchDocument{ + ID: src["id"].(string), + SearchNode: model.SearchNode{ + Parent: src["parent"].(string), + Name: src["name"].(string), + IsDir: src["is_dir"].(bool), + Size: int64(src["size"].(float64)), + }, + }, nil + }) +} + +func (m *Meilisearch) Get(ctx context.Context, parent string) ([]model.SearchNode, error) { + result, err := m.getDocumentsByParent(ctx, parent) + if err != nil { + return nil, err + } + return utils.SliceConvert(result, func(src *searchDocument) (model.SearchNode, error) { + return src.SearchNode, nil + }) + +} + +func (m *Meilisearch) getParentsByPrefix(ctx context.Context, parent string) ([]string, error) { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + parents := []string{parent} + get, err := m.getDocumentsByParent(ctx, parent) + if err != nil { + return nil, err + } + for _, node := range get { + if node.IsDir { + arr, err := m.getParentsByPrefix(ctx, path.Join(node.Parent, node.Name)) + if err != nil { + return nil, err + } + parents = append(parents, arr...) + } + } + return parents, nil + } +} + +func (m *Meilisearch) DelDirChild(ctx context.Context, prefix string) error { + dfs, err := m.getParentsByPrefix(ctx, utils.FixAndCleanPath(prefix)) + if err != nil { + return err + } + utils.SliceReplace(dfs, func(src string) string { + return "'" + strings.ReplaceAll(src, "'", "\\'") + "'" + }) + s := fmt.Sprintf("parent IN [%s]", strings.Join(dfs, ",")) + task, err := m.Client.Index(m.IndexUid).DeleteDocumentsByFilter(s) + if err != nil { + return err + } + taskStatus, err := m.getTaskStatus(ctx, task.TaskUID) + if err != nil { + return err + } + if taskStatus != meilisearch.TaskStatusSucceeded { + return fmt.Errorf("DelDir failed, task status is %s", taskStatus) + } + return nil +} + +func (m *Meilisearch) Del(ctx context.Context, prefix string) error { + prefix = utils.FixAndCleanPath(prefix) + dir, name := path.Split(prefix) + get, err := m.getDocumentsByParent(ctx, dir[:len(dir)-1]) + if err != nil { + return err + } + var document *searchDocument + for _, v := range get { + if v.Name == name { + document = v + break + } + } + if document == nil { + // Defensive programming. 
The document may be a folder; try deleting its children + return m.DelDirChild(ctx, prefix) + } + if document.IsDir { + err = m.DelDirChild(ctx, prefix) + if err != nil { + return err + } + } + task, err := m.Client.Index(m.IndexUid).DeleteDocument(document.ID) + if err != nil { + return err + } + taskStatus, err := m.getTaskStatus(ctx, task.TaskUID) + if err != nil { + return err + } + if taskStatus != meilisearch.TaskStatusSucceeded { + return fmt.Errorf("Del failed, task status is %s", taskStatus) + } + return nil +} + +func (m *Meilisearch) Release(ctx context.Context) error { + return nil +} + +func (m *Meilisearch) Clear(ctx context.Context) error { + _, err := m.Client.Index(m.IndexUid).DeleteAllDocuments() + return err +} + +func (m *Meilisearch) getTaskStatus(ctx context.Context, taskUID int64) (meilisearch.TaskStatus, error) { + forTask, err := m.Client.WaitForTask(taskUID, meilisearch.WaitParams{ + Context: ctx, + Interval: time.Second, + }) + if err != nil { + return meilisearch.TaskStatusUnknown, err + } + return forTask.Status, nil +} diff --git a/internal/search/search.go b/internal/search/search.go new file mode 100644 index 0000000000000000000000000000000000000000..c1a23b85ca82d6a1768f8bdd5613c056ed70b63d --- /dev/null +++ b/internal/search/search.go @@ -0,0 +1,95 @@ +package search + +import ( + "context" + "fmt" + + "github.com/alist-org/alist/v3/internal/conf" + "github.com/alist-org/alist/v3/internal/errs" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/internal/op" + "github.com/alist-org/alist/v3/internal/search/searcher" + log "github.com/sirupsen/logrus" +) + +var instance searcher.Searcher = nil + +// Init initializes or resets the search index +func Init(mode string) error { + if instance != nil { + // unchanged, do nothing + if instance.Config().Name == mode { + return nil + } + err := instance.Release(context.Background()) + if err != nil { + log.Errorf("release instance err: %+v", err) + } + instance = nil + } + if Running.Load() { + return fmt.Errorf("indexing is in progress") + } + if mode == "none" { + log.Warnf("search is not enabled") + return nil + } + s, ok := searcher.NewMap[mode] + if !ok { + return fmt.Errorf("unsupported search index: %s", mode) + } + i, err := s() + if err != nil { + log.Errorf("init searcher error: %+v", err) + } else { + instance = i + } + return err +} + +func Search(ctx context.Context, req model.SearchReq) ([]model.SearchNode, int64, error) { + // guard against a nil instance, matching Index and BatchIndex below + if instance == nil { + return nil, 0, errs.SearchNotAvailable + } + return instance.Search(ctx, req) +} + +func Index(ctx context.Context, parent string, obj model.Obj) error { + if instance == nil { + return errs.SearchNotAvailable + } + return instance.Index(ctx, model.SearchNode{ + Parent: parent, + Name: obj.GetName(), + IsDir: obj.IsDir(), + Size: obj.GetSize(), + }) +} + +type ObjWithParent struct { + Parent string + model.Obj +} + +func BatchIndex(ctx context.Context, objs []ObjWithParent) error { + if instance == nil { + return errs.SearchNotAvailable + } + if len(objs) == 0 { + return nil + } + var searchNodes []model.SearchNode + for i := range objs { + searchNodes = append(searchNodes, model.SearchNode{ + Parent: objs[i].Parent, + Name: objs[i].GetName(), + IsDir: objs[i].IsDir(), + Size: objs[i].GetSize(), + }) + } + return instance.BatchIndex(ctx, searchNodes) +} + +func init() { + op.RegisterSettingItemHook(conf.SearchIndex, func(item *model.SettingItem) error { + log.Debugf("searcher init, mode: %s", item.Value) + return Init(item.Value) + }) +} diff --git a/internal/search/searcher/manage.go b/internal/search/searcher/manage.go new file mode 100644
index 0000000000000000000000000000000000000000..92bdd883a92c7312741648fde1a807fc050b3267 --- /dev/null +++ b/internal/search/searcher/manage.go @@ -0,0 +1,9 @@ +package searcher + +type New func() (Searcher, error) + +var NewMap = map[string]New{} + +func RegisterSearcher(config Config, searcher New) { + NewMap[config.Name] = searcher +} diff --git a/internal/search/searcher/searcher.go b/internal/search/searcher/searcher.go new file mode 100644 index 0000000000000000000000000000000000000000..6b753931429dd08e8dbf25b9d0d6edb8ba590ca4 --- /dev/null +++ b/internal/search/searcher/searcher.go @@ -0,0 +1,31 @@ +package searcher + +import ( + "context" + + "github.com/alist-org/alist/v3/internal/model" +) + +type Config struct { + Name string + AutoUpdate bool +} + +type Searcher interface { + // Config of the searcher + Config() Config + // Search specific keywords in specific path + Search(ctx context.Context, req model.SearchReq) ([]model.SearchNode, int64, error) + // Index obj with parent + Index(ctx context.Context, node model.SearchNode) error + // BatchIndex obj with parent + BatchIndex(ctx context.Context, nodes []model.SearchNode) error + // Get by parent + Get(ctx context.Context, parent string) ([]model.SearchNode, error) + // Del with prefix + Del(ctx context.Context, prefix string) error + // Release resource + Release(ctx context.Context) error + // Clear all index + Clear(ctx context.Context) error +} diff --git a/internal/search/util.go b/internal/search/util.go new file mode 100644 index 0000000000000000000000000000000000000000..8d03b740c332b8e018a15b379d60d2fb4018617a --- /dev/null +++ b/internal/search/util.go @@ -0,0 +1,96 @@ +package search + +import ( + "strings" + + "github.com/alist-org/alist/v3/drivers/alist_v3" + "github.com/alist-org/alist/v3/drivers/base" + "github.com/alist-org/alist/v3/internal/conf" + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/internal/op" + "github.com/alist-org/alist/v3/internal/setting" + "github.com/alist-org/alist/v3/pkg/utils" + log "github.com/sirupsen/logrus" +) + +func Progress() (*model.IndexProgress, error) { + p := setting.GetStr(conf.IndexProgress) + var progress model.IndexProgress + err := utils.Json.UnmarshalFromString(p, &progress) + return &progress, err +} + +func WriteProgress(progress *model.IndexProgress) { + p, err := utils.Json.MarshalToString(progress) + if err != nil { + log.Errorf("marshal progress error: %+v", err) + } + err = op.SaveSettingItem(&model.SettingItem{ + Key: conf.IndexProgress, + Value: p, + Type: conf.TypeText, + Group: model.SINGLE, + Flag: model.PRIVATE, + }) + if err != nil { + log.Errorf("save progress error: %+v", err) + } +} + +func updateIgnorePaths() { + storages := op.GetAllStorages() + ignorePaths := make([]string, 0) + var skipDrivers = []string{"AList V2", "AList V3", "Virtual"} + v3Visited := make(map[string]bool) + for _, storage := range storages { + if utils.SliceContains(skipDrivers, storage.Config().Name) { + if storage.Config().Name == "AList V3" { + addition := storage.GetAddition().(*alist_v3.Addition) + allowIndexed, visited := v3Visited[addition.Address] + if !visited { + url := addition.Address + "/api/public/settings" + res, err := base.RestyClient.R().Get(url) + if err == nil { + log.Debugf("allow_indexed body: %+v", res.String()) + allowIndexed = utils.Json.Get(res.Body(), "data", conf.AllowIndexed).ToString() == "true" + v3Visited[addition.Address] = allowIndexed + } + } + log.Debugf("%s 
allow_indexed: %v", addition.Address, allowIndexed) + if !allowIndexed { + ignorePaths = append(ignorePaths, storage.GetStorage().MountPath) + } + } else { + ignorePaths = append(ignorePaths, storage.GetStorage().MountPath) + } + } + } + customIgnorePaths := setting.GetStr(conf.IgnorePaths) + if customIgnorePaths != "" { + ignorePaths = append(ignorePaths, strings.Split(customIgnorePaths, "\n")...) + } + conf.SlicesMap[conf.IgnorePaths] = ignorePaths +} + +func isIgnorePath(path string) bool { + for _, ignorePath := range conf.SlicesMap[conf.IgnorePaths] { + if strings.HasPrefix(path, ignorePath) { + return true + } + } + return false +} + +func init() { + op.RegisterSettingItemHook(conf.IgnorePaths, func(item *model.SettingItem) error { + updateIgnorePaths() + return nil + }) + op.RegisterStorageHook(func(typ string, storage driver.Driver) { + var skipDrivers = []string{"AList V2", "AList V3", "Virtual"} + if utils.SliceContains(skipDrivers, storage.Config().Name) { + updateIgnorePaths() + } + }) +} diff --git a/internal/setting/setting.go b/internal/setting/setting.go new file mode 100644 index 0000000000000000000000000000000000000000..cd77874bec8b00bf87dfa6b05376cbaefbb9e757 --- /dev/null +++ b/internal/setting/setting.go @@ -0,0 +1,30 @@ +package setting + +import ( + "strconv" + + "github.com/alist-org/alist/v3/internal/op" +) + +func GetStr(key string, defaultValue ...string) string { + val, _ := op.GetSettingItemByKey(key) + if val == nil { + if len(defaultValue) > 0 { + return defaultValue[0] + } + return "" + } + return val.Value +} + +func GetInt(key string, defaultVal int) int { + i, err := strconv.Atoi(GetStr(key)) + if err != nil { + return defaultVal + } + return i +} + +func GetBool(key string) bool { + return GetStr(key) == "true" || GetStr(key) == "1" +} diff --git a/internal/sign/sign.go b/internal/sign/sign.go new file mode 100644 index 0000000000000000000000000000000000000000..978ae7cc7ad9c8f0e546c3c1f52ec08aa83e5660 --- /dev/null +++ b/internal/sign/sign.go @@ -0,0 +1,41 @@ +package sign + +import ( + "sync" + "time" + + "github.com/alist-org/alist/v3/internal/conf" + "github.com/alist-org/alist/v3/internal/setting" + "github.com/alist-org/alist/v3/pkg/sign" +) + +var once sync.Once +var instance sign.Sign + +func Sign(data string) string { + expire := setting.GetInt(conf.LinkExpiration, 0) + if expire == 0 { + return NotExpired(data) + } else { + return WithDuration(data, time.Duration(expire)*time.Hour) + } +} + +func WithDuration(data string, d time.Duration) string { + once.Do(Instance) + return instance.Sign(data, time.Now().Add(d).Unix()) +} + +func NotExpired(data string) string { + once.Do(Instance) + return instance.Sign(data, 0) +} + +func Verify(data string, sign string) error { + once.Do(Instance) + return instance.Verify(data, sign) +} + +func Instance() { + instance = sign.NewHMACSign([]byte(setting.GetStr(conf.Token))) +} diff --git a/internal/stream/stream.go b/internal/stream/stream.go new file mode 100644 index 0000000000000000000000000000000000000000..4b882c519e0988c09287524a345171065329ecdf --- /dev/null +++ b/internal/stream/stream.go @@ -0,0 +1,249 @@ +package stream + +import ( + "bytes" + "context" + "errors" + "fmt" + "io" + "os" + + "github.com/alist-org/alist/v3/internal/errs" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/pkg/http_range" + "github.com/alist-org/alist/v3/pkg/utils" +) + +type FileStream struct { + Ctx context.Context + model.Obj + io.Reader + Mimetype string + WebPutAsTask bool + 
ForceStreamUpload bool + Exist model.Obj //the file existed in the destination, we can reuse some info since we wil overwrite it + utils.Closers + tmpFile *os.File //if present, tmpFile has full content, it will be deleted at last + peekBuff *bytes.Reader +} + +func (f *FileStream) GetSize() int64 { + if f.tmpFile != nil { + info, err := f.tmpFile.Stat() + if err == nil { + return info.Size() + } + } + return f.Obj.GetSize() +} + +func (f *FileStream) GetMimetype() string { + return f.Mimetype +} + +func (f *FileStream) NeedStore() bool { + return f.WebPutAsTask +} + +func (f *FileStream) IsForceStreamUpload() bool { + return f.ForceStreamUpload +} + +func (f *FileStream) Close() error { + var err1, err2 error + err1 = f.Closers.Close() + if f.tmpFile != nil { + err2 = os.RemoveAll(f.tmpFile.Name()) + if err2 != nil { + err2 = errs.NewErr(err2, "failed to remove tmpFile [%s]", f.tmpFile.Name()) + } + } + + return errors.Join(err1, err2) +} + +func (f *FileStream) GetExist() model.Obj { + return f.Exist +} +func (f *FileStream) SetExist(obj model.Obj) { + f.Exist = obj +} + +// CacheFullInTempFile save all data into tmpFile. Not recommended since it wears disk, +// and can't start upload until the file is written. It's not thread-safe! +func (f *FileStream) CacheFullInTempFile() (model.File, error) { + if f.tmpFile != nil { + return f.tmpFile, nil + } + if file, ok := f.Reader.(model.File); ok { + return file, nil + } + tmpF, err := utils.CreateTempFile(f.Reader, f.GetSize()) + if err != nil { + return nil, err + } + f.Add(tmpF) + f.tmpFile = tmpF + f.Reader = tmpF + return f.tmpFile, nil +} + +const InMemoryBufMaxSize = 10 // Megabytes +const InMemoryBufMaxSizeBytes = InMemoryBufMaxSize * 1024 * 1024 + +// RangeRead have to cache all data first since only Reader is provided. 
+// also support a peeking RangeRead at very start, but won't buffer more than 10MB data in memory +func (f *FileStream) RangeRead(httpRange http_range.Range) (io.Reader, error) { + if httpRange.Length == -1 { + httpRange.Length = f.GetSize() + } + if f.peekBuff != nil && httpRange.Start < int64(f.peekBuff.Len()) && httpRange.Start+httpRange.Length-1 < int64(f.peekBuff.Len()) { + return io.NewSectionReader(f.peekBuff, httpRange.Start, httpRange.Length), nil + } + if f.tmpFile == nil { + if httpRange.Start == 0 && httpRange.Length <= InMemoryBufMaxSizeBytes && f.peekBuff == nil { + bufSize := utils.Min(httpRange.Length, f.GetSize()) + newBuf := bytes.NewBuffer(make([]byte, 0, bufSize)) + n, err := io.CopyN(newBuf, f.Reader, bufSize) + if err != nil { + return nil, err + } + if n != bufSize { + return nil, fmt.Errorf("stream RangeRead did not get all data in peek, expect =%d ,actual =%d", bufSize, n) + } + f.peekBuff = bytes.NewReader(newBuf.Bytes()) + f.Reader = io.MultiReader(f.peekBuff, f.Reader) + return io.NewSectionReader(f.peekBuff, httpRange.Start, httpRange.Length), nil + } else { + _, err := f.CacheFullInTempFile() + if err != nil { + return nil, err + } + } + } + return io.NewSectionReader(f.tmpFile, httpRange.Start, httpRange.Length), nil +} + +var _ model.FileStreamer = (*SeekableStream)(nil) +var _ model.FileStreamer = (*FileStream)(nil) + +//var _ seekableStream = (*FileStream)(nil) + +// for most internal stream, which is either RangeReadCloser or MFile +type SeekableStream struct { + FileStream + Link *model.Link + // should have one of belows to support rangeRead + rangeReadCloser model.RangeReadCloserIF + mFile model.File +} + +func NewSeekableStream(fs FileStream, link *model.Link) (*SeekableStream, error) { + if len(fs.Mimetype) == 0 { + fs.Mimetype = utils.GetMimeType(fs.Obj.GetName()) + } + ss := SeekableStream{FileStream: fs, Link: link} + if ss.Reader != nil { + result, ok := ss.Reader.(model.File) + if ok { + ss.mFile = result + ss.Closers.Add(result) + return &ss, nil + } + } + if ss.Link != nil { + if ss.Link.MFile != nil { + ss.mFile = ss.Link.MFile + ss.Reader = ss.Link.MFile + ss.Closers.Add(ss.Link.MFile) + return &ss, nil + } + + if ss.Link.RangeReadCloser != nil { + ss.rangeReadCloser = ss.Link.RangeReadCloser + return &ss, nil + } + if len(ss.Link.URL) > 0 { + rrc, err := GetRangeReadCloserFromLink(ss.GetSize(), link) + if err != nil { + return nil, err + } + ss.rangeReadCloser = rrc + return &ss, nil + } + } + + return nil, fmt.Errorf("illegal seekableStream") +} + +//func (ss *SeekableStream) Peek(length int) { +// +//} + +// RangeRead is not thread-safe, pls use it in single thread only. +func (ss *SeekableStream) RangeRead(httpRange http_range.Range) (io.Reader, error) { + if httpRange.Length == -1 { + httpRange.Length = ss.GetSize() + } + if ss.mFile != nil { + return io.NewSectionReader(ss.mFile, httpRange.Start, httpRange.Length), nil + } + if ss.tmpFile != nil { + return io.NewSectionReader(ss.tmpFile, httpRange.Start, httpRange.Length), nil + } + if ss.rangeReadCloser != nil { + rc, err := ss.rangeReadCloser.RangeRead(ss.Ctx, httpRange) + if err != nil { + return nil, err + } + return rc, nil + } + return nil, fmt.Errorf("can't find mFile or rangeReadCloser") +} + +//func (f *FileStream) GetReader() io.Reader { +// return f.Reader +//} + +// only provide Reader as full stream when it's demanded. 
in rapid-upload, we can skip this to save memory +func (ss *SeekableStream) Read(p []byte) (n int, err error) { + //f.mu.Lock() + + //f.peekedOnce = true + //defer f.mu.Unlock() + if ss.Reader == nil { + if ss.rangeReadCloser == nil { + return 0, fmt.Errorf("illegal seekableStream") + } + rc, err := ss.rangeReadCloser.RangeRead(ss.Ctx, http_range.Range{Length: -1}) + if err != nil { + // propagate the error; returning (0, nil) here would violate the io.Reader contract + return 0, err + } + ss.Reader = io.NopCloser(rc) + ss.Closers.Add(rc) + + } + return ss.Reader.Read(p) +} + +func (ss *SeekableStream) CacheFullInTempFile() (model.File, error) { + if ss.tmpFile != nil { + return ss.tmpFile, nil + } + if ss.mFile != nil { + return ss.mFile, nil + } + tmpF, err := utils.CreateTempFile(ss, ss.GetSize()) + if err != nil { + return nil, err + } + ss.Add(tmpF) + ss.tmpFile = tmpF + ss.Reader = tmpF + return ss.tmpFile, nil +} + +func (f *FileStream) SetTmpFile(r *os.File) { + f.Reader = r + f.tmpFile = r +} diff --git a/internal/stream/util.go b/internal/stream/util.go new file mode 100644 index 0000000000000000000000000000000000000000..7d2b7ef75097363b54a512e1baf721491e0f0490 --- /dev/null +++ b/internal/stream/util.go @@ -0,0 +1,88 @@ +package stream + +import ( + "context" + "fmt" + "io" + "net/http" + + "github.com/alist-org/alist/v3/internal/errs" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/internal/net" + "github.com/alist-org/alist/v3/pkg/http_range" + log "github.com/sirupsen/logrus" +) + +func GetRangeReadCloserFromLink(size int64, link *model.Link) (model.RangeReadCloserIF, error) { + if len(link.URL) == 0 { + return nil, fmt.Errorf("can't create RangeReadCloser since URL is empty in link") + } + //remoteClosers := utils.EmptyClosers() + rangeReaderFunc := func(ctx context.Context, r http_range.Range) (io.ReadCloser, error) { + if link.Concurrency != 0 || link.PartSize != 0 { + header := net.ProcessHeader(http.Header{}, link.Header) + down := net.NewDownloader(func(d *net.Downloader) { + d.Concurrency = link.Concurrency + d.PartSize = link.PartSize + }) + req := &net.HttpRequestParams{ + URL: link.URL, + Range: r, + Size: size, + HeaderRef: header, + } + rc, err := down.Download(ctx, req) + if err != nil { + return nil, errs.NewErr(err, "GetReadCloserFromLink failed") + } + return rc, nil + + } + if len(link.URL) > 0 { + response, err := RequestRangedHttp(ctx, link, r.Start, r.Length) + if err != nil { + if response == nil { + return nil, fmt.Errorf("http request failure, err: %s", err) + } + return nil, fmt.Errorf("http request failure, status: %d, err: %s", response.StatusCode, err) + } + if r.Start == 0 && (r.Length == -1 || r.Length == size) || response.StatusCode == http.StatusPartialContent || + checkContentRange(&response.Header, r.Start) { + return response.Body, nil + } else if response.StatusCode == http.StatusOK { + log.Warnf("remote http server does not support range requests, expect low performance!") + readCloser, err := net.GetRangedHttpReader(response.Body, r.Start, r.Length) + if err != nil { + return nil, err + } + return readCloser, nil + + } + + return response.Body, nil + } + + return nil, errs.NotSupport + } + resultRangeReadCloser := model.RangeReadCloser{RangeReader: rangeReaderFunc} + return &resultRangeReadCloser, nil +} + +func RequestRangedHttp(ctx context.Context, link *model.Link, offset, length int64) (*http.Response, error) { + header := net.ProcessHeader(http.Header{}, link.Header) + header = http_range.ApplyRangeToHttpHeader(http_range.Range{Start: offset, Length: length}, header) + + return
net.RequestHttp(ctx, "GET", header, link.URL) +} + +// 139 cloud does not properly return 206 http status code, add a hack here +func checkContentRange(header *http.Header, offset int64) bool { + start, _, err := http_range.ParseContentRange(header.Get("Content-Range")) + if err != nil { + log.Warnf("exception trying to parse Content-Range, will ignore,err=%s", err) + } + if start == offset { + return true + } + return false +} diff --git a/main.go b/main.go new file mode 100644 index 0000000000000000000000000000000000000000..ecf0a643e68a9f6798185e565dec0ea496e98c24 --- /dev/null +++ b/main.go @@ -0,0 +1,7 @@ +package main + +import "github.com/alist-org/alist/v3/cmd" + +func main() { + cmd.Execute() +} diff --git a/pkg/aria2/rpc/README.md b/pkg/aria2/rpc/README.md new file mode 100644 index 0000000000000000000000000000000000000000..a7191f5faf72b98766dc8669294ad06595a93102 --- /dev/null +++ b/pkg/aria2/rpc/README.md @@ -0,0 +1,257 @@ +# PACKAGE DOCUMENTATION + +**package rpc** + + import "github.com/matzoe/argo/rpc" + + + +## FUNCTIONS + +``` +func Call(address, method string, params, reply interface{}) error +``` + +## TYPES + +``` +type Client struct { + // contains filtered or unexported fields +} +``` + +``` +func New(uri string) *Client +``` + +``` +func (id *Client) AddMetalink(uri string, options ...interface{}) (gid string, err error) +``` +`aria2.addMetalink(metalink[, options[, position]])` This method adds Metalink download by uploading ".metalink" file. `metalink` is of type base64 which contains Base64-encoded ".metalink" file. `options` is of type struct and its members are a pair of option name and value. See Options below for more details. If `position` is given as an integer starting from 0, the new download is inserted at `position` in the +waiting queue. If `position` is not given or `position` is larger than the size of the queue, it is appended at the end of the queue. This method returns array of GID of registered download. If `--rpc-save-upload-metadata` is true, the uploaded data is saved as a file named hex string of SHA-1 hash of data plus ".metalink" in the directory specified by `--dir` option. The example of filename is 0a3893293e27ac0490424c06de4d09242215f0a6.metalink. If same file already exists, it is overwritten. If the file cannot be saved successfully or `--rpc-save-upload-metadata` is false, the downloads added by this method are not saved by `--save-session`. + +``` +func (id *Client) AddTorrent(filename string, options ...interface{}) (gid string, err error) +``` +`aria2.addTorrent(torrent[, uris[, options[, position]]])` This method adds BitTorrent download by uploading ".torrent" file. If you want to add BitTorrent Magnet URI, use `aria2.addUri()` method instead. torrent is of type base64 which contains Base64-encoded ".torrent" file. `uris` is of type array and its element is URI which is of type string. `uris` is used for Web-seeding. For single file torrents, URI can be a complete URI pointing to the resource or if URI ends with /, name in torrent file is added. For multi-file torrents, name and path in torrent are added to form a URI for each file. options is of type struct and its members are +a pair of option name and value. See Options below for more details. If `position` is given as an integer starting from 0, the new download is inserted at `position` in the waiting queue. If `position` is not given or `position` is larger than the size of the queue, it is appended at the end of the queue. This method returns GID of registered download. 
If `--rpc-save-upload-metadata` is true, the uploaded data is saved as a file named hex string of SHA-1 hash of data plus ".torrent" in the +directory specified by `--dir` option. The example of filename is 0a3893293e27ac0490424c06de4d09242215f0a6.torrent. If same file already exists, it is overwritten. If the file cannot be saved successfully or `--rpc-save-upload-metadata` is false, the downloads added by this method are not saved by -`-save-session`. + +``` +func (id *Client) AddUri(uri string, options ...interface{}) (gid string, err error) +``` + +`aria2.addUri(uris[, options[, position]])` This method adds new HTTP(S)/FTP/BitTorrent Magnet URI. `uris` is of type array and its element is URI which is of type string. For BitTorrent Magnet URI, `uris` must have only one element and it should be BitTorrent Magnet URI. URIs in uris must point to the same file. If you mix other URIs which point to another file, aria2 does not complain but download may +fail. `options` is of type struct and its members are a pair of option name and value. See Options below for more details. If `position` is given as an integer starting from 0, the new download is inserted at position in the waiting queue. If `position` is not given or `position` is larger than the size of the queue, it is appended at the end of the queue. This method returns GID of registered download. + +``` +func (id *Client) ChangeGlobalOption(options map[string]interface{}) (g string, err error) +``` + +`aria2.changeGlobalOption(options)` This method changes global options dynamically. `options` is of type struct. The following `options` are available: + + download-result + log + log-level + max-concurrent-downloads + max-download-result + max-overall-download-limit + max-overall-upload-limit + save-cookies + save-session + server-stat-of + +In addition to them, options listed in Input File subsection are available, except for following options: `checksum`, `index-out`, `out`, `pause` and `select-file`. Using `log` option, you can dynamically start logging or change log file. To stop logging, give empty string("") as a parameter value. Note that log file is always opened in append mode. This method returns OK for success. + +``` +func (id *Client) ChangeOption(gid string, options map[string]interface{}) (g string, err error) +``` + +`aria2.changeOption(gid, options)` This method changes options of the download denoted by `gid` dynamically. `gid` is of type string. `options` is of type struct. The following `options` are available for active downloads: + + bt-max-peers + bt-request-peer-speed-limit + bt-remove-unselected-file + force-save + max-download-limit + max-upload-limit + +For waiting or paused downloads, in addition to the above options, options listed in Input File subsection are available, except for following options: dry-run, metalink-base-uri, parameterized-uri, pause, piece-length and rpc-save-upload-metadata option. This method returns OK for success. + +``` +func (id *Client) ChangePosition(gid string, pos int, how string) (p int, err error) +``` + +`aria2.changePosition(gid, pos, how)` This method changes the position of the download denoted by `gid`. `pos` is of type integer. `how` is of type string. If `how` is `POS_SET`, it moves the download to a position relative to the beginning of the queue. If `how` is `POS_CUR`, it moves the download to a position relative to the current position. If `how` is `POS_END`, it moves the download to a position relative to the end of the queue. 
If the destination position is less than 0 or beyond the end +of the queue, it moves the download to the beginning or the end of the queue respectively. The response is of type integer and it is the destination position. + +``` +func (id *Client) ChangeUri(gid string, fileindex int, delUris []string, addUris []string, position ...int) (p []int, err error) +``` + +`aria2.changeUri(gid, fileIndex, delUris, addUris[, position])` This method removes URIs in `delUris` from and appends URIs in `addUris` to download denoted by gid. `delUris` and `addUris` are list of string. A download can contain multiple files and URIs are attached to each file. `fileIndex` is used to select which file to remove/attach given URIs. `fileIndex` is 1-based. `position` is used to specify where URIs are inserted in the existing waiting URI list. `position` is 0-based. When +`position` is omitted, URIs are appended to the back of the list. This method first execute removal and then addition. `position` is the `position` after URIs are removed, not the `position` when this method is called. When removing URI, if same URIs exist in download, only one of them is removed for each URI in delUris. In other words, there are three URIs http://example.org/aria2 and you want remove them all, you +have to specify (at least) 3 http://example.org/aria2 in delUris. This method returns a list which contains 2 integers. The first integer is the number of URIs deleted. The second integer is the number of URIs added. + +``` +func (id *Client) ForcePause(gid string) (g string, err error) +``` + +`aria2.forcePause(pid)` This method pauses the download denoted by `gid`. This method behaves just like aria2.pause() except that this method pauses download without any action which takes time such as contacting BitTorrent tracker. + +``` +func (id *Client) ForcePauseAll() (g string, err error) +``` + +`aria2.forcePauseAll()` This method is equal to calling `aria2.forcePause()` for every active/waiting download. This methods returns OK for success. + +``` +func (id *Client) ForceRemove(gid string) (g string, err error) +``` + +`aria2.forceRemove(gid)` This method removes the download denoted by `gid`. This method behaves just like aria2.remove() except that this method removes download without any action which takes time such as contacting BitTorrent tracker. + +``` +func (id *Client) ForceShutdown() (g string, err error) +``` + +`aria2.forceShutdown()` This method shutdowns aria2. This method behaves like `aria2.shutdown()` except that any actions which takes time such as contacting BitTorrent tracker are skipped. This method returns OK. + +``` +func (id *Client) GetFiles(gid string) (m map[string]interface{}, err error) +``` + +`aria2.getFiles(gid)` This method returns file list of the download denoted by `gid`. `gid` is of type string. + +``` +func (id *Client) GetGlobalOption() (m map[string]interface{}, err error) +``` + +`aria2.getGlobalOption()` This method returns global options. The response is of type struct. Its key is the name of option. The value type is string. Note that this method does not return options which have no default value and have not been set by the command-line options, configuration files or RPC methods. Because global options are used as a template for the options of newly added download, the response contains +keys returned by `aria2.getOption()` method. 
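To make the getters above concrete, here is a minimal usage sketch. It relies only on signatures documented in this README (`New`, `GetGlobalOption`); the import path is assumed from this repository's layout (the README's original examples import `github.com/matzoe/argo/rpc`), the endpoint is a placeholder, and the vendored implementation may have drifted from these signatures, so treat it as illustrative rather than canonical.

```
package main

import (
	"fmt"
	"log"

	// Assumed in-repo import path; the README's original import is
	// "github.com/matzoe/argo/rpc".
	"github.com/alist-org/alist/v3/pkg/aria2/rpc"
)

func main() {
	// Placeholder endpoint; aria2 must be running with --enable-rpc.
	client := rpc.New("http://localhost:6800/jsonrpc")

	// Global options come back as a map of option name -> value.
	opts, err := client.GetGlobalOption()
	if err != nil {
		log.Fatalf("GetGlobalOption: %v", err)
	}
	fmt.Println("max-concurrent-downloads:", opts["max-concurrent-downloads"])
}
```

The same pattern applies to the per-download getters such as `aria2.getOption(gid)` and to `aria2.getGlobalStat()` below.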
+ +``` +func (id *Client) GetGlobalStat() (m map[string]interface{}, err error) +``` + +`aria2.getGlobalStat()` This method returns global statistics such as overall download and upload speed. + +``` +func (id *Client) GetOption(gid string) (m map[string]interface{}, err error) +``` + +`aria2.getOption(gid)` This method returns options of the download denoted by `gid`. The response is of type struct. Its key is the name of option. The value type is string. Note that this method does not return options which have no default value and have not been set by the command-line options, configuration files or RPC methods. + +``` +func (id *Client) GetPeers(gid string) (m []map[string]interface{}, err error) +``` + +`aria2.getPeers(gid)` This method returns peer list of the download denoted by `gid`. `gid` is of type string. This method is for BitTorrent only. + +``` +func (id *Client) GetServers(gid string) (m []map[string]interface{}, err error) +``` + +`aria2.getServers(gid)` This method returns currently connected HTTP(S)/FTP servers of the download denoted by `gid`. `gid` is of type string. + +``` +func (id *Client) GetSessionInfo() (m map[string]interface{}, err error) +``` + +`aria2.getSessionInfo()` This method returns session information. + +``` +func (id *Client) GetUris(gid string) (m map[string]interface{}, err error) +``` + +`aria2.getUris(gid)` This method returns URIs used in the download denoted by `gid`. `gid` is of type string. + +``` +func (id *Client) GetVersion() (m map[string]interface{}, err error) +``` + +`aria2.getVersion()` This method returns version of the program and the list of enabled features. + +``` +func (id *Client) Multicall(methods []map[string]interface{}) (r []interface{}, err error) +``` + +`system.multicall(methods)` This method encapsulates multiple method calls in a single request. `methods` is of type array and its element is struct. The struct contains two keys: `methodName` and `params`. `methodName` is the method name to call and `params` is array containing parameters to the method. This method returns array of responses. The element of array will either be a one-item array containing the return value of each method call or struct of fault element if an encapsulated method call fails. + +``` +func (id *Client) Pause(gid string) (g string, err error) +``` + +`aria2.pause(gid)` This method pauses the download denoted by `gid`. `gid` is of type string. The status of paused download becomes paused. If the download is active, the download is placed on the first position of waiting queue. As long as the status is paused, the download is not started. To change status to waiting, use `aria2.unpause()` method. This method returns GID of paused download. + +``` +func (id *Client) PauseAll() (g string, err error) +``` + +`aria2.pauseAll()` This method is equal to calling `aria2.pause()` for every active/waiting download. This methods returns OK for success. + +``` +func (id *Client) PurgeDownloadResult() (g string, err error) +``` + +`aria2.purgeDownloadResult()` This method purges completed/error/removed downloads to free memory. This method returns OK. + +``` +func (id *Client) Remove(gid string) (g string, err error) +``` + +`aria2.remove(gid)` This method removes the download denoted by gid. `gid` is of type string. If specified download is in progress, it is stopped at first. The status of removed download becomes removed. This method returns GID of removed download. 
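The pause/inspect/remove methods documented above compose into a simple shutdown flow for a single download. The helper below is a hypothetical sketch assembled only from signatures in this README; `stopDownload` and its package name are illustrative, and `gid` is whatever GID a prior `aria2.addUri()` call returned.

```
package rpcexample

import (
	"fmt"

	// Assumed in-repo import path, as in the sketch further above.
	"github.com/alist-org/alist/v3/pkg/aria2/rpc"
)

// stopDownload is a hypothetical helper, not part of the rpc package.
func stopDownload(c *rpc.Client, gid string) error {
	// Pause moves an active download to the paused state at the first
	// position of the waiting queue.
	if _, err := c.Pause(gid); err != nil {
		return err
	}
	// Request only the keys we need to avoid unnecessary transfers.
	status, err := c.TellStatus(gid, "gid", "status")
	if err != nil {
		return err
	}
	fmt.Printf("download %v is now %v\n", status["gid"], status["status"])
	// Remove stops the transfer; aria2.removeDownloadResult (below) can
	// then purge the stopped result from memory.
	_, err = c.Remove(gid)
	return err
}
```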
+ +``` +func (id *Client) RemoveDownloadResult(gid string) (g string, err error) +``` + +`aria2.removeDownloadResult(gid)` This method removes completed/error/removed download denoted by `gid` from memory. This method returns OK for success. + +``` +func (id *Client) Shutdown() (g string, err error) +``` + +`aria2.shutdown()` This method shutdowns aria2. This method returns OK. + +``` +func (id *Client) TellActive(keys ...string) (m []map[string]interface{}, err error) +``` + +`aria2.tellActive([keys])` This method returns the list of active downloads. The response is of type array and its element is the same struct returned by `aria2.tellStatus()` method. For `keys` parameter, please refer to `aria2.tellStatus()` method. + +``` +func (id *Client) TellStatus(gid string, keys ...string) (m map[string]interface{}, err error) +``` + +`aria2.tellStatus(gid[, keys])` This method returns download progress of the download denoted by `gid`. `gid` is of type string. `keys` is array of string. If it is specified, the response contains only keys in `keys` array. If `keys` is empty or not specified, the response contains all keys. This is useful when you just want specific keys and avoid unnecessary transfers. For example, `aria2.tellStatus("2089b05ecca3d829", ["gid", "status"])` returns `gid` and `status` key. + +``` +func (id *Client) TellStopped(offset, num int, keys ...string) (m []map[string]interface{}, err error) +``` + +`aria2.tellStopped(offset, num[, keys])` This method returns the list of stopped download. `offset` is of type integer and specifies the `offset` from the oldest download. `num` is of type integer and specifies the number of downloads to be returned. For keys parameter, please refer to `aria2.tellStatus()` method. `offset` and `num` have the same semantics as `aria2.tellWaiting()` method. The response is of type array and its element is the same struct returned by `aria2.tellStatus()` method. + +``` +func (id *Client) TellWaiting(offset, num int, keys ...string) (m []map[string]interface{}, err error) +``` +`aria2.tellWaiting(offset, num[, keys])` This method returns the list of waiting download, including paused downloads. `offset` is of type integer and specifies the `offset` from the download waiting at the front. num is of type integer and specifies the number of downloads to be returned. For keys parameter, please refer to aria2.tellStatus() method. If `offset` is a positive integer, this method returns downloads +in the range of `[offset, offset + num)`. `offset` can be a negative integer. `offset == -1` points last download in the waiting queue and `offset == -2` points the download before the last download, and so on. The downloads in the response are in reversed order. For example, imagine that three downloads "A","B" and "C" are waiting in this order. + + aria2.tellWaiting(0, 1) returns ["A"]. + aria2.tellWaiting(1, 2) returns ["B", "C"]. + aria2.tellWaiting(-1, 2) returns ["C", "B"]. + +The response is of type array and its element is the same struct returned by `aria2.tellStatus()` method. + +``` +func (id *Client) Unpause(gid string) (g string, err error) +``` + +`aria2.unpause(gid)` This method changes the status of the download denoted by `gid` from paused to waiting. This makes the download eligible to restart. `gid` is of type string. This method returns GID of unpaused download. + +``` +func (id *Client) UnpauseAll() (g string, err error) +``` + +`aria2.unpauseAll()` This method is equal to calling `aria2.unpause()` for every active/waiting download. 
This methods returns OK for success. diff --git a/pkg/aria2/rpc/call.go b/pkg/aria2/rpc/call.go new file mode 100644 index 0000000000000000000000000000000000000000..a2af84617ae44ae0e3215a00fe22463894771a5c --- /dev/null +++ b/pkg/aria2/rpc/call.go @@ -0,0 +1,274 @@ +package rpc + +import ( + "context" + "errors" + "net" + "net/http" + "net/url" + "sync" + "sync/atomic" + "time" + + "github.com/gorilla/websocket" + log "github.com/sirupsen/logrus" +) + +type caller interface { + // Call sends a request of rpc to aria2 daemon + Call(method string, params, reply interface{}) (err error) + Close() error +} + +type httpCaller struct { + uri string + c *http.Client + cancel context.CancelFunc + wg *sync.WaitGroup + once sync.Once +} + +func newHTTPCaller(ctx context.Context, u *url.URL, timeout time.Duration, notifier Notifier) *httpCaller { + c := &http.Client{ + Transport: &http.Transport{ + MaxIdleConnsPerHost: 1, + MaxConnsPerHost: 1, + // TLSClientConfig: tlsConfig, + Dial: (&net.Dialer{ + Timeout: timeout, + KeepAlive: 60 * time.Second, + }).Dial, + TLSHandshakeTimeout: 3 * time.Second, + ResponseHeaderTimeout: timeout, + }, + } + var wg sync.WaitGroup + ctx, cancel := context.WithCancel(ctx) + h := &httpCaller{uri: u.String(), c: c, cancel: cancel, wg: &wg} + if notifier != nil { + h.setNotifier(ctx, *u, notifier) + } + return h +} + +func (h *httpCaller) Close() (err error) { + h.once.Do(func() { + h.cancel() + h.wg.Wait() + }) + return +} + +func (h *httpCaller) setNotifier(ctx context.Context, u url.URL, notifier Notifier) (err error) { + u.Scheme = "ws" + conn, _, err := websocket.DefaultDialer.Dial(u.String(), nil) + if err != nil { + return + } + h.wg.Add(1) + go func() { + defer h.wg.Done() + defer conn.Close() + select { + case <-ctx.Done(): + conn.SetWriteDeadline(time.Now().Add(time.Second)) + if err := conn.WriteMessage(websocket.CloseMessage, + websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")); err != nil { + log.Printf("sending websocket close message: %v", err) + } + return + } + }() + h.wg.Add(1) + go func() { + defer h.wg.Done() + var request websocketResponse + var err error + for { + select { + case <-ctx.Done(): + return + default: + } + if err = conn.ReadJSON(&request); err != nil { + select { + case <-ctx.Done(): + return + default: + } + log.Printf("conn.ReadJSON|err:%v", err.Error()) + return + } + switch request.Method { + case "aria2.onDownloadStart": + notifier.OnDownloadStart(request.Params) + case "aria2.onDownloadPause": + notifier.OnDownloadPause(request.Params) + case "aria2.onDownloadStop": + notifier.OnDownloadStop(request.Params) + case "aria2.onDownloadComplete": + notifier.OnDownloadComplete(request.Params) + case "aria2.onDownloadError": + notifier.OnDownloadError(request.Params) + case "aria2.onBtDownloadComplete": + notifier.OnBtDownloadComplete(request.Params) + default: + log.Printf("unexpected notification: %s", request.Method) + } + } + }() + return +} + +func (h httpCaller) Call(method string, params, reply interface{}) (err error) { + payload, err := EncodeClientRequest(method, params) + if err != nil { + return + } + r, err := h.c.Post(h.uri, "application/json", payload) + if err != nil { + return + } + err = DecodeClientResponse(r.Body, &reply) + r.Body.Close() + return +} + +type websocketCaller struct { + conn *websocket.Conn + sendChan chan *sendRequest + cancel context.CancelFunc + wg *sync.WaitGroup + once sync.Once + timeout time.Duration +} + +func newWebsocketCaller(ctx context.Context, uri string, timeout time.Duration, 
notifier Notifier) (*websocketCaller, error) { + var header = http.Header{} + conn, _, err := websocket.DefaultDialer.Dial(uri, header) + if err != nil { + return nil, err + } + + sendChan := make(chan *sendRequest, 16) + var wg sync.WaitGroup + ctx, cancel := context.WithCancel(ctx) + w := &websocketCaller{conn: conn, wg: &wg, cancel: cancel, sendChan: sendChan, timeout: timeout} + processor := NewResponseProcessor() + wg.Add(1) + go func() { // routine:recv + defer wg.Done() + defer cancel() + for { + select { + case <-ctx.Done(): + return + default: + } + var resp websocketResponse + if err := conn.ReadJSON(&resp); err != nil { + select { + case <-ctx.Done(): + return + default: + } + log.Printf("conn.ReadJSON|err:%v", err.Error()) + return + } + if resp.Id == nil { // RPC notifications + if notifier != nil { + switch resp.Method { + case "aria2.onDownloadStart": + notifier.OnDownloadStart(resp.Params) + case "aria2.onDownloadPause": + notifier.OnDownloadPause(resp.Params) + case "aria2.onDownloadStop": + notifier.OnDownloadStop(resp.Params) + case "aria2.onDownloadComplete": + notifier.OnDownloadComplete(resp.Params) + case "aria2.onDownloadError": + notifier.OnDownloadError(resp.Params) + case "aria2.onBtDownloadComplete": + notifier.OnBtDownloadComplete(resp.Params) + default: + log.Printf("unexpected notification: %s", resp.Method) + } + } + continue + } + processor.Process(resp.clientResponse) + } + }() + wg.Add(1) + go func() { // routine:send + defer wg.Done() + defer cancel() + defer w.conn.Close() + + for { + select { + case <-ctx.Done(): + if err := w.conn.WriteMessage(websocket.CloseMessage, + websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")); err != nil { + log.Printf("sending websocket close message: %v", err) + } + return + case req := <-sendChan: + processor.Add(req.request.Id, func(resp clientResponse) error { + err := resp.decode(req.reply) + req.cancel() + return err + }) + w.conn.SetWriteDeadline(time.Now().Add(timeout)) + w.conn.WriteJSON(req.request) + } + } + }() + + return w, nil +} + +func (w *websocketCaller) Close() (err error) { + w.once.Do(func() { + w.cancel() + w.wg.Wait() + }) + return +} + +func (w websocketCaller) Call(method string, params, reply interface{}) (err error) { + ctx, cancel := context.WithTimeout(context.Background(), w.timeout) + defer cancel() + select { + case w.sendChan <- &sendRequest{cancel: cancel, request: &clientRequest{ + Version: "2.0", + Method: method, + Params: params, + Id: reqid(), + }, reply: reply}: + + default: + return errors.New("sending channel blocking") + } + + select { + case <-ctx.Done(): + if err := ctx.Err(); err == context.DeadlineExceeded { + return err + } + } + return +} + +type sendRequest struct { + cancel context.CancelFunc + request *clientRequest + reply interface{} +} + +var reqid = func() func() uint64 { + var id = uint64(time.Now().UnixNano()) + return func() uint64 { + return atomic.AddUint64(&id, 1) + } +}() diff --git a/pkg/aria2/rpc/call_test.go b/pkg/aria2/rpc/call_test.go new file mode 100644 index 0000000000000000000000000000000000000000..64d25200e40d2792b8eb2b4e9eb5ba4f638be9f7 --- /dev/null +++ b/pkg/aria2/rpc/call_test.go @@ -0,0 +1,23 @@ +package rpc + +import ( + "context" + "testing" + "time" +) + +func TestWebsocketCaller(t *testing.T) { + time.Sleep(time.Second) + c, err := newWebsocketCaller(context.Background(), "ws://localhost:6800/jsonrpc", time.Second, &DummyNotifier{}) + if err != nil { + t.Fatal(err.Error()) + } + defer c.Close() + + var info VersionInfo + if err := 
c.Call(aria2GetVersion, []interface{}{}, &info); err != nil { + t.Error(err.Error()) + } else { + println(info.Version) + } +} diff --git a/pkg/aria2/rpc/client.go b/pkg/aria2/rpc/client.go new file mode 100644 index 0000000000000000000000000000000000000000..041e9d4ffd33f73fd0ab102239041c892b7e3fef --- /dev/null +++ b/pkg/aria2/rpc/client.go @@ -0,0 +1,665 @@ +package rpc + +import ( + "context" + "encoding/base64" + "errors" + "net/url" + "os" + "time" +) + +// Option is a container for specifying Call parameters and returning results +type Option map[string]interface{} + +type Client interface { + Protocol + Close() error +} + +type client struct { + caller + url *url.URL + token string +} + +var ( + errInvalidParameter = errors.New("invalid parameter") + errNotImplemented = errors.New("not implemented") + errConnTimeout = errors.New("connect to aria2 daemon timeout") +) + +// New returns an instance of Client +func New(ctx context.Context, uri string, token string, timeout time.Duration, notifier Notifier) (Client, error) { + u, err := url.Parse(uri) + if err != nil { + return nil, err + } + var caller caller + switch u.Scheme { + case "http", "https": + caller = newHTTPCaller(ctx, u, timeout, notifier) + case "ws", "wss": + caller, err = newWebsocketCaller(ctx, u.String(), timeout, notifier) + if err != nil { + return nil, err + } + default: + return nil, errInvalidParameter + } + c := &client{caller: caller, url: u, token: token} + return c, nil +} + +// `aria2.addUri([secret, ]uris[, options[, position]])` +// This method adds a new download. uris is an array of HTTP/FTP/SFTP/BitTorrent URIs (strings) pointing to the same resource. +// If you mix URIs pointing to different resources, then the download may fail or be corrupted without aria2 complaining. +// When adding BitTorrent Magnet URIs, uris must have only one element and it should be BitTorrent Magnet URI. +// options is a struct and its members are pairs of option name and value. +// If position is given, it must be an integer starting from 0. +// The new download will be inserted at position in the waiting queue. +// If position is omitted or position is larger than the current size of the queue, the new download is appended to the end of the queue. +// This method returns the GID of the newly registered download. +func (c *client) AddURI(uris []string, options ...interface{}) (gid string, err error) { + params := make([]interface{}, 0, 2) + if c.token != "" { + params = append(params, "token:"+c.token) + } + params = append(params, uris) + if options != nil { + params = append(params, options...) + } + err = c.Call(aria2AddURI, params, &gid) + return +} + +// `aria2.addTorrent([secret, ]torrent[, uris[, options[, position]]])` +// This method adds a BitTorrent download by uploading a ".torrent" file. +// If you want to add a BitTorrent Magnet URI, use the aria2.addUri() method instead. +// torrent must be a base64-encoded string containing the contents of the ".torrent" file. +// uris is an array of URIs (string). uris is used for Web-seeding. +// For single file torrents, the URI can be a complete URI pointing to the resource; if URI ends with /, name in torrent file is added. +// For multi-file torrents, name and path in torrent are added to form a URI for each file. options is a struct and its members are pairs of option name and value. +// If position is given, it must be an integer starting from 0. +// The new download will be inserted at position in the waiting queue. 
+// If position is omitted or position is larger than the current size of the queue, the new download is appended to the end of the queue. +// This method returns the GID of the newly registered download. +// If --rpc-save-upload-metadata is true, the uploaded data is saved as a file named as the hex string of SHA-1 hash of data plus ".torrent" in the directory specified by --dir option. +// E.g. a file name might be 0a3893293e27ac0490424c06de4d09242215f0a6.torrent. +// If a file with the same name already exists, it is overwritten! +// If the file cannot be saved successfully or --rpc-save-upload-metadata is false, the downloads added by this method are not saved by --save-session. +func (c *client) AddTorrent(filename string, options ...interface{}) (gid string, err error) { + co, err := os.ReadFile(filename) + if err != nil { + return + } + file := base64.StdEncoding.EncodeToString(co) + params := make([]interface{}, 0, 3) + if c.token != "" { + params = append(params, "token:"+c.token) + } + params = append(params, file) + params = append(params, []interface{}{}) + if options != nil { + params = append(params, options...) + } + err = c.Call(aria2AddTorrent, params, &gid) + return +} + +// `aria2.addMetalink([secret, ]metalink[, options[, position]])` +// This method adds a Metalink download by uploading a ".metalink" file. +// metalink is a base64-encoded string which contains the contents of the ".metalink" file. +// options is a struct and its members are pairs of option name and value. +// If position is given, it must be an integer starting from 0. +// The new download will be inserted at position in the waiting queue. +// If position is omitted or position is larger than the current size of the queue, the new download is appended to the end of the queue. +// This method returns an array of GIDs of newly registered downloads. +// If --rpc-save-upload-metadata is true, the uploaded data is saved as a file named hex string of SHA-1 hash of data plus ".metalink" in the directory specified by --dir option. +// E.g. a file name might be 0a3893293e27ac0490424c06de4d09242215f0a6.metalink. +// If a file with the same name already exists, it is overwritten! +// If the file cannot be saved successfully or --rpc-save-upload-metadata is false, the downloads added by this method are not saved by --save-session. +func (c *client) AddMetalink(filename string, options ...interface{}) (gid []string, err error) { + co, err := os.ReadFile(filename) + if err != nil { + return + } + file := base64.StdEncoding.EncodeToString(co) + params := make([]interface{}, 0, 2) + if c.token != "" { + params = append(params, "token:"+c.token) + } + params = append(params, file) + if options != nil { + params = append(params, options...) + } + err = c.Call(aria2AddMetalink, params, &gid) + return +} + +// `aria2.remove([secret, ]gid)` +// This method removes the download denoted by gid (string). +// If the specified download is in progress, it is first stopped. +// The status of the removed download becomes removed. +// This method returns GID of removed download. +func (c *client) Remove(gid string) (g string, err error) { + params := make([]interface{}, 0, 2) + if c.token != "" { + params = append(params, "token:"+c.token) + } + params = append(params, gid) + err = c.Call(aria2Remove, params, &g) + return +} + +// `aria2.forceRemove([secret, ]gid)` +// This method removes the download denoted by gid. 
+// This method behaves just like aria2.remove() except that this method removes the download without performing any actions which take time, such as contacting BitTorrent trackers to unregister the download first. +func (c *client) ForceRemove(gid string) (g string, err error) { + params := make([]interface{}, 0, 2) + if c.token != "" { + params = append(params, "token:"+c.token) + } + params = append(params, gid) + err = c.Call(aria2ForceRemove, params, &g) + return +} + +// `aria2.pause([secret, ]gid)` +// This method pauses the download denoted by gid (string). +// The status of paused download becomes paused. +// If the download was active, the download is placed in the front of waiting queue. +// While the status is paused, the download is not started. +// To change status to waiting, use the aria2.unpause() method. +// This method returns GID of paused download. +func (c *client) Pause(gid string) (g string, err error) { + params := make([]interface{}, 0, 2) + if c.token != "" { + params = append(params, "token:"+c.token) + } + params = append(params, gid) + err = c.Call(aria2Pause, params, &g) + return +} + +// `aria2.pauseAll([secret])` +// This method is equal to calling aria2.pause() for every active/waiting download. +// This methods returns OK. +func (c *client) PauseAll() (ok string, err error) { + params := []string{} + if c.token != "" { + params = append(params, "token:"+c.token) + } + err = c.Call(aria2PauseAll, params, &ok) + return +} + +// `aria2.forcePause([secret, ]gid)` +// This method pauses the download denoted by gid. +// This method behaves just like aria2.pause() except that this method pauses downloads without performing any actions which take time, such as contacting BitTorrent trackers to unregister the download first. +func (c *client) ForcePause(gid string) (g string, err error) { + params := make([]interface{}, 0, 2) + if c.token != "" { + params = append(params, "token:"+c.token) + } + params = append(params, gid) + err = c.Call(aria2ForcePause, params, &g) + return +} + +// `aria2.forcePauseAll([secret])` +// This method is equal to calling aria2.forcePause() for every active/waiting download. +// This methods returns OK. +func (c *client) ForcePauseAll() (ok string, err error) { + params := []string{} + if c.token != "" { + params = append(params, "token:"+c.token) + } + err = c.Call(aria2ForcePauseAll, params, &ok) + return +} + +// `aria2.unpause([secret, ]gid)` +// This method changes the status of the download denoted by gid (string) from paused to waiting, making the download eligible to be restarted. +// This method returns the GID of the unpaused download. +func (c *client) Unpause(gid string) (g string, err error) { + params := make([]interface{}, 0, 2) + if c.token != "" { + params = append(params, "token:"+c.token) + } + params = append(params, gid) + err = c.Call(aria2Unpause, params, &g) + return +} + +// `aria2.unpauseAll([secret])` +// This method is equal to calling aria2.unpause() for every active/waiting download. +// This methods returns OK. +func (c *client) UnpauseAll() (ok string, err error) { + params := []string{} + if c.token != "" { + params = append(params, "token:"+c.token) + } + err = c.Call(aria2UnpauseAll, params, &ok) + return +} + +// `aria2.tellStatus([secret, ]gid[, keys])` +// This method returns the progress of the download denoted by gid (string). +// keys is an array of strings. +// If specified, the response contains only keys in the keys array. +// If keys is empty or omitted, the response contains all keys. 
+// This is useful when you just want specific keys and avoid unnecessary transfers. +// For example, aria2.tellStatus("2089b05ecca3d829", ["gid", "status"]) returns the gid and status keys only. +// The response is a struct and contains following keys. Values are strings. +// https://aria2.github.io/manual/en/html/aria2c.html#aria2.tellStatus +func (c *client) TellStatus(gid string, keys ...string) (info StatusInfo, err error) { + params := make([]interface{}, 0, 2) + if c.token != "" { + params = append(params, "token:"+c.token) + } + params = append(params, gid) + if keys != nil { + params = append(params, keys) + } + err = c.Call(aria2TellStatus, params, &info) + return +} + +// `aria2.getUris([secret, ]gid)` +// This method returns the URIs used in the download denoted by gid (string). +// The response is an array of structs and it contains following keys. Values are string. +// +// uri URI +// status 'used' if the URI is in use. 'waiting' if the URI is still waiting in the queue. +func (c *client) GetURIs(gid string) (infos []URIInfo, err error) { + params := make([]interface{}, 0, 2) + if c.token != "" { + params = append(params, "token:"+c.token) + } + params = append(params, gid) + err = c.Call(aria2GetURIs, params, &infos) + return +} + +// `aria2.getFiles([secret, ]gid)` +// This method returns the file list of the download denoted by gid (string). +// The response is an array of structs which contain following keys. Values are strings. +// https://aria2.github.io/manual/en/html/aria2c.html#aria2.getFiles +func (c *client) GetFiles(gid string) (infos []FileInfo, err error) { + params := make([]interface{}, 0, 2) + if c.token != "" { + params = append(params, "token:"+c.token) + } + params = append(params, gid) + err = c.Call(aria2GetFiles, params, &infos) + return +} + +// `aria2.getPeers([secret, ]gid)` +// This method returns a list peers of the download denoted by gid (string). +// This method is for BitTorrent only. +// The response is an array of structs and contains the following keys. Values are strings. +// https://aria2.github.io/manual/en/html/aria2c.html#aria2.getPeers +func (c *client) GetPeers(gid string) (infos []PeerInfo, err error) { + params := make([]interface{}, 0, 2) + if c.token != "" { + params = append(params, "token:"+c.token) + } + params = append(params, gid) + err = c.Call(aria2GetPeers, params, &infos) + return +} + +// `aria2.getServers([secret, ]gid)` +// This method returns currently connected HTTP(S)/FTP/SFTP servers of the download denoted by gid (string). +// The response is an array of structs and contains the following keys. Values are strings. +// https://aria2.github.io/manual/en/html/aria2c.html#aria2.getServers +func (c *client) GetServers(gid string) (infos []ServerInfo, err error) { + params := make([]interface{}, 0, 2) + if c.token != "" { + params = append(params, "token:"+c.token) + } + params = append(params, gid) + err = c.Call(aria2GetServers, params, &infos) + return +} + +// `aria2.tellActive([secret][, keys])` +// This method returns a list of active downloads. +// The response is an array of the same structs as returned by the aria2.tellStatus() method. +// For the keys parameter, please refer to the aria2.tellStatus() method. 
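+//
+// A hedged usage sketch (the key names come from the aria2.tellStatus docs):
+//
+//	infos, err := c.TellActive("gid", "status", "completedLength")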
+func (c *client) TellActive(keys ...string) (infos []StatusInfo, err error) { + params := make([]interface{}, 0, 1) + if c.token != "" { + params = append(params, "token:"+c.token) + } + if keys != nil { + params = append(params, keys) + } + err = c.Call(aria2TellActive, params, &infos) + return +} + +// `aria2.tellWaiting([secret, ]offset, num[, keys])` +// This method returns a list of waiting downloads, including paused ones. +// offset is an integer and specifies the offset from the download waiting at the front. +// num is an integer and specifies the max. number of downloads to be returned. +// For the keys parameter, please refer to the aria2.tellStatus() method. +// If offset is a positive integer, this method returns downloads in the range of [offset, offset + num). +// offset can be a negative integer. offset == -1 points last download in the waiting queue and offset == -2 points the download before the last download, and so on. +// Downloads in the response are in reversed order then. +// For example, imagine three downloads "A","B" and "C" are waiting in this order. +// aria2.tellWaiting(0, 1) returns ["A"]. +// aria2.tellWaiting(1, 2) returns ["B", "C"]. +// aria2.tellWaiting(-1, 2) returns ["C", "B"]. +// The response is an array of the same structs as returned by aria2.tellStatus() method. +func (c *client) TellWaiting(offset, num int, keys ...string) (infos []StatusInfo, err error) { + params := make([]interface{}, 0, 3) + if c.token != "" { + params = append(params, "token:"+c.token) + } + params = append(params, offset) + params = append(params, num) + if keys != nil { + params = append(params, keys) + } + err = c.Call(aria2TellWaiting, params, &infos) + return +} + +// `aria2.tellStopped([secret, ]offset, num[, keys])` +// This method returns a list of stopped downloads. +// offset is an integer and specifies the offset from the least recently stopped download. +// num is an integer and specifies the max. number of downloads to be returned. +// For the keys parameter, please refer to the aria2.tellStatus() method. +// offset and num have the same semantics as described in the aria2.tellWaiting() method. +// The response is an array of the same structs as returned by the aria2.tellStatus() method. +func (c *client) TellStopped(offset, num int, keys ...string) (infos []StatusInfo, err error) { + params := make([]interface{}, 0, 3) + if c.token != "" { + params = append(params, "token:"+c.token) + } + params = append(params, offset) + params = append(params, num) + if keys != nil { + params = append(params, keys) + } + err = c.Call(aria2TellStopped, params, &infos) + return +} + +// `aria2.changePosition([secret, ]gid, pos, how)` +// This method changes the position of the download denoted by gid in the queue. +// pos is an integer. how is a string. +// If how is POS_SET, it moves the download to a position relative to the beginning of the queue. +// If how is POS_CUR, it moves the download to a position relative to the current position. +// If how is POS_END, it moves the download to a position relative to the end of the queue. +// If the destination position is less than 0 or beyond the end of the queue, it moves the download to the beginning or the end of the queue respectively. +// The response is an integer denoting the resulting position. +// For example, if GID#2089b05ecca3d829 is currently in position 3, aria2.changePosition('2089b05ecca3d829', -1, 'POS_CUR') will change its position to 2. 
Additionally aria2.changePosition('2089b05ecca3d829', 0, 'POS_SET') will change its position to 0 (the beginning of the queue). +func (c *client) ChangePosition(gid string, pos int, how string) (p int, err error) { + params := make([]interface{}, 0, 3) + if c.token != "" { + params = append(params, "token:"+c.token) + } + params = append(params, gid) + params = append(params, pos) + params = append(params, how) + err = c.Call(aria2ChangePosition, params, &p) + return +} + +// `aria2.changeUri([secret, ]gid, fileIndex, delUris, addUris[, position])` +// This method removes the URIs in delUris from and appends the URIs in addUris to download denoted by gid. +// delUris and addUris are lists of strings. +// A download can contain multiple files and URIs are attached to each file. +// fileIndex is used to select which file to remove/attach given URIs. fileIndex is 1-based. +// position is used to specify where URIs are inserted in the existing waiting URI list. position is 0-based. +// When position is omitted, URIs are appended to the back of the list. +// This method first executes the removal and then the addition. +// position is the position after URIs are removed, not the position when this method is called. +// When removing an URI, if the same URIs exist in download, only one of them is removed for each URI in delUris. +// In other words, if there are three URIs http://example.org/aria2 and you want remove them all, you have to specify (at least) 3 http://example.org/aria2 in delUris. +// This method returns a list which contains two integers. +// The first integer is the number of URIs deleted. +// The second integer is the number of URIs added. +func (c *client) ChangeURI(gid string, fileindex int, delUris []string, addUris []string, position ...int) (p []int, err error) { + params := make([]interface{}, 0, 5) + if c.token != "" { + params = append(params, "token:"+c.token) + } + params = append(params, gid) + params = append(params, fileindex) + params = append(params, delUris) + params = append(params, addUris) + if position != nil { + params = append(params, position[0]) + } + err = c.Call(aria2ChangeURI, params, &p) + return +} + +// `aria2.getOption([secret, ]gid)` +// This method returns options of the download denoted by gid. +// The response is a struct where keys are the names of options. +// The values are strings. +// Note that this method does not return options which have no default value and have not been set on the command-line, in configuration files or RPC methods. +func (c *client) GetOption(gid string) (m Option, err error) { + params := make([]interface{}, 0, 2) + if c.token != "" { + params = append(params, "token:"+c.token) + } + params = append(params, gid) + err = c.Call(aria2GetOption, params, &m) + return +} + +// `aria2.changeOption([secret, ]gid, options)` +// This method changes options of the download denoted by gid (string) dynamically. options is a struct. +// The following options are available for active downloads: +// +// bt-max-peers +// bt-request-peer-speed-limit +// bt-remove-unselected-file +// force-save +// max-download-limit +// max-upload-limit +// +// For waiting or paused downloads, in addition to the above options, options listed in Input File subsection are available, except for following options: dry-run, metalink-base-uri, parameterized-uri, pause, piece-length and rpc-save-upload-metadata option. +// This method returns OK for success. 
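+//
+// A hedged example (the gid variable and the limit value are illustrative):
+//
+//	ok, err := c.ChangeOption(gid, Option{"max-download-limit": "100K"})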
+func (c *client) ChangeOption(gid string, option Option) (ok string, err error) { + params := make([]interface{}, 0, 2) + if c.token != "" { + params = append(params, "token:"+c.token) + } + params = append(params, gid) + if option != nil { + params = append(params, option) + } + err = c.Call(aria2ChangeOption, params, &ok) + return +} + +// `aria2.getGlobalOption([secret])` +// This method returns the global options. +// The response is a struct. +// Its keys are the names of options. +// Values are strings. +// Note that this method does not return options which have no default value and have not been set on the command-line, in configuration files or RPC methods. Because global options are used as a template for the options of newly added downloads, the response contains keys returned by the aria2.getOption() method. +func (c *client) GetGlobalOption() (m Option, err error) { + params := []string{} + if c.token != "" { + params = append(params, "token:"+c.token) + } + err = c.Call(aria2GetGlobalOption, params, &m) + return +} + +// `aria2.changeGlobalOption([secret, ]options)` +// This method changes global options dynamically. +// options is a struct. +// The following options are available: +// +// bt-max-open-files +// download-result +// log +// log-level +// max-concurrent-downloads +// max-download-result +// max-overall-download-limit +// max-overall-upload-limit +// save-cookies +// save-session +// server-stat-of +// +// In addition, options listed in the Input File subsection are available, except for following options: checksum, index-out, out, pause and select-file. +// With the log option, you can dynamically start logging or change log file. +// To stop logging, specify an empty string("") as the parameter value. +// Note that log file is always opened in append mode. +// This method returns OK for success. +func (c *client) ChangeGlobalOption(options Option) (ok string, err error) { + params := make([]interface{}, 0, 2) + if c.token != "" { + params = append(params, "token:"+c.token) + } + params = append(params, options) + err = c.Call(aria2ChangeGlobalOption, params, &ok) + return +} + +// `aria2.getGlobalStat([secret])` +// This method returns global statistics such as the overall download and upload speeds. +// The response is a struct and contains the following keys. Values are strings. +// +// downloadSpeed Overall download speed (byte/sec). +// uploadSpeed Overall upload speed(byte/sec). +// numActive The number of active downloads. +// numWaiting The number of waiting downloads. +// numStopped The number of stopped downloads in the current session. +// This value is capped by the --max-download-result option. +// numStoppedTotal The number of stopped downloads in the current session and not capped by the --max-download-result option. +func (c *client) GetGlobalStat() (info GlobalStatInfo, err error) { + params := []string{} + if c.token != "" { + params = append(params, "token:"+c.token) + } + err = c.Call(aria2GetGlobalStat, params, &info) + return +} + +// `aria2.purgeDownloadResult([secret])` +// This method purges completed/error/removed downloads to free memory. +// This method returns OK. +func (c *client) PurgeDownloadResult() (ok string, err error) { + params := []string{} + if c.token != "" { + params = append(params, "token:"+c.token) + } + err = c.Call(aria2PurgeDownloadResult, params, &ok) + return +} + +// `aria2.removeDownloadResult([secret, ]gid)` +// This method removes a completed/error/removed download denoted by gid from memory. 
+// This method returns OK for success. +func (c *client) RemoveDownloadResult(gid string) (ok string, err error) { + params := make([]interface{}, 0, 2) + if c.token != "" { + params = append(params, "token:"+c.token) + } + params = append(params, gid) + err = c.Call(aria2RemoveDownloadResult, params, &ok) + return +} + +// `aria2.getVersion([secret])` +// This method returns the version of aria2 and the list of enabled features. +// The response is a struct and contains following keys. +// +// version Version number of aria2 as a string. +// enabledFeatures List of enabled features. Each feature is given as a string. +func (c *client) GetVersion() (info VersionInfo, err error) { + params := []string{} + if c.token != "" { + params = append(params, "token:"+c.token) + } + err = c.Call(aria2GetVersion, params, &info) + return +} + +// `aria2.getSessionInfo([secret])` +// This method returns session information. +// The response is a struct and contains following key. +// +// sessionId Session ID, which is generated each time when aria2 is invoked. +func (c *client) GetSessionInfo() (info SessionInfo, err error) { + params := []string{} + if c.token != "" { + params = append(params, "token:"+c.token) + } + err = c.Call(aria2GetSessionInfo, params, &info) + return +} + +// `aria2.shutdown([secret])` +// This method shutdowns aria2. +// This method returns OK. +func (c *client) Shutdown() (ok string, err error) { + params := []string{} + if c.token != "" { + params = append(params, "token:"+c.token) + } + err = c.Call(aria2Shutdown, params, &ok) + return +} + +// `aria2.forceShutdown([secret])` +// This method shuts down aria2(). +// This method behaves like :func:'aria2.shutdown` without performing any actions which take time, such as contacting BitTorrent trackers to unregister downloads first. +// This method returns OK. +func (c *client) ForceShutdown() (ok string, err error) { + params := []string{} + if c.token != "" { + params = append(params, "token:"+c.token) + } + err = c.Call(aria2ForceShutdown, params, &ok) + return +} + +// `aria2.saveSession([secret])` +// This method saves the current session to a file specified by the --save-session option. +// This method returns OK if it succeeds. +func (c *client) SaveSession() (ok string, err error) { + params := []string{} + if c.token != "" { + params = append(params, "token:"+c.token) + } + err = c.Call(aria2SaveSession, params, &ok) + return +} + +// `system.multicall(methods)` +// This methods encapsulates multiple method calls in a single request. +// methods is an array of structs. +// The structs contain two keys: methodName and params. +// methodName is the method name to call and params is array containing parameters to the method call. +// This method returns an array of responses. +// The elements will be either a one-item array containing the return value of the method call or a struct of fault element if an encapsulated method call fails. +func (c *client) Multicall(methods []Method) (r []interface{}, err error) { + if len(methods) == 0 { + err = errInvalidParameter + return + } + err = c.Call(aria2Multicall, []interface{}{methods}, &r) + return +} + +// `system.listMethods()` +// This method returns the all available RPC methods in an array of string. +// Unlike other methods, this method does not require secret token. +// This is safe because this method just returns the available method names. 
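+//
+// A hedged example: methods, err := c.ListMethods() // contains e.g. "aria2.addUri"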
+func (c *client) ListMethods() (methods []string, err error) { + err = c.Call(aria2ListMethods, []string{}, &methods) + return +} diff --git a/pkg/aria2/rpc/client_test.go b/pkg/aria2/rpc/client_test.go new file mode 100644 index 0000000000000000000000000000000000000000..512363b4875310416e44ad79504f0f8ec84ffe97 --- /dev/null +++ b/pkg/aria2/rpc/client_test.go @@ -0,0 +1,125 @@ +package rpc + +import ( + "context" + "testing" + "time" +) + +func TestHTTPAll(t *testing.T) { + const targetURL = "https://nodejs.org/dist/index.json" + rpc, err := New(context.Background(), "http://localhost:6800/jsonrpc", "", time.Second, &DummyNotifier{}) + if err != nil { + t.Fatal(err) + } + defer rpc.Close() + g, err := rpc.AddURI([]string{targetURL}) + if err != nil { + t.Fatal(err) + } + println(g) + if _, err = rpc.TellActive(); err != nil { + t.Error(err) + } + if _, err = rpc.PauseAll(); err != nil { + t.Error(err) + } + if _, err = rpc.TellStatus(g); err != nil { + t.Error(err) + } + if _, err = rpc.GetURIs(g); err != nil { + t.Error(err) + } + if _, err = rpc.GetFiles(g); err != nil { + t.Error(err) + } + if _, err = rpc.GetPeers(g); err != nil { + t.Error(err) + } + if _, err = rpc.TellActive(); err != nil { + t.Error(err) + } + if _, err = rpc.TellWaiting(0, 1); err != nil { + t.Error(err) + } + if _, err = rpc.TellStopped(0, 1); err != nil { + t.Error(err) + } + if _, err = rpc.GetOption(g); err != nil { + t.Error(err) + } + if _, err = rpc.GetGlobalOption(); err != nil { + t.Error(err) + } + if _, err = rpc.GetGlobalStat(); err != nil { + t.Error(err) + } + if _, err = rpc.GetSessionInfo(); err != nil { + t.Error(err) + } + if _, err = rpc.Remove(g); err != nil { + t.Error(err) + } + if _, err = rpc.TellActive(); err != nil { + t.Error(err) + } +} + +func TestWebsocketAll(t *testing.T) { + const targetURL = "https://nodejs.org/dist/index.json" + rpc, err := New(context.Background(), "ws://localhost:6800/jsonrpc", "", time.Second, &DummyNotifier{}) + if err != nil { + t.Fatal(err) + } + defer rpc.Close() + g, err := rpc.AddURI([]string{targetURL}) + if err != nil { + t.Fatal(err) + } + println(g) + if _, err = rpc.TellActive(); err != nil { + t.Error(err) + } + if _, err = rpc.PauseAll(); err != nil { + t.Error(err) + } + if _, err = rpc.TellStatus(g); err != nil { + t.Error(err) + } + if _, err = rpc.GetURIs(g); err != nil { + t.Error(err) + } + if _, err = rpc.GetFiles(g); err != nil { + t.Error(err) + } + if _, err = rpc.GetPeers(g); err != nil { + t.Error(err) + } + if _, err = rpc.TellActive(); err != nil { + t.Error(err) + } + if _, err = rpc.TellWaiting(0, 1); err != nil { + t.Error(err) + } + if _, err = rpc.TellStopped(0, 1); err != nil { + t.Error(err) + } + if _, err = rpc.GetOption(g); err != nil { + t.Error(err) + } + if _, err = rpc.GetGlobalOption(); err != nil { + t.Error(err) + } + if _, err = rpc.GetGlobalStat(); err != nil { + t.Error(err) + } + if _, err = rpc.GetSessionInfo(); err != nil { + t.Error(err) + } + if _, err = rpc.Remove(g); err != nil { + t.Error(err) + } + if _, err = rpc.TellActive(); err != nil { + t.Error(err) + } +} diff --git a/pkg/aria2/rpc/const.go b/pkg/aria2/rpc/const.go new file mode 100644 index 0000000000000000000000000000000000000000..b5d83dd8508ad8ffd9148a469ecc32378952bc67 --- /dev/null +++ b/pkg/aria2/rpc/const.go @@ -0,0 +1,39 @@ +package rpc + +const ( + aria2AddURI = "aria2.addUri" + aria2AddTorrent = "aria2.addTorrent" + aria2AddMetalink = "aria2.addMetalink" + aria2Remove = "aria2.remove" + aria2ForceRemove = "aria2.forceRemove" + aria2Pause 
= "aria2.pause" + aria2PauseAll = "aria2.pauseAll" + aria2ForcePause = "aria2.forcePause" + aria2ForcePauseAll = "aria2.forcePauseAll" + aria2Unpause = "aria2.unpause" + aria2UnpauseAll = "aria2.unpauseAll" + aria2TellStatus = "aria2.tellStatus" + aria2GetURIs = "aria2.getUris" + aria2GetFiles = "aria2.getFiles" + aria2GetPeers = "aria2.getPeers" + aria2GetServers = "aria2.getServers" + aria2TellActive = "aria2.tellActive" + aria2TellWaiting = "aria2.tellWaiting" + aria2TellStopped = "aria2.tellStopped" + aria2ChangePosition = "aria2.changePosition" + aria2ChangeURI = "aria2.changeUri" + aria2GetOption = "aria2.getOption" + aria2ChangeOption = "aria2.changeOption" + aria2GetGlobalOption = "aria2.getGlobalOption" + aria2ChangeGlobalOption = "aria2.changeGlobalOption" + aria2GetGlobalStat = "aria2.getGlobalStat" + aria2PurgeDownloadResult = "aria2.purgeDownloadResult" + aria2RemoveDownloadResult = "aria2.removeDownloadResult" + aria2GetVersion = "aria2.getVersion" + aria2GetSessionInfo = "aria2.getSessionInfo" + aria2Shutdown = "aria2.shutdown" + aria2ForceShutdown = "aria2.forceShutdown" + aria2SaveSession = "aria2.saveSession" + aria2Multicall = "system.multicall" + aria2ListMethods = "system.listMethods" +) diff --git a/pkg/aria2/rpc/json2.go b/pkg/aria2/rpc/json2.go new file mode 100644 index 0000000000000000000000000000000000000000..3febf7eaa614ded0e4e318b1e663647154a4406b --- /dev/null +++ b/pkg/aria2/rpc/json2.go @@ -0,0 +1,116 @@ +package rpc + +// based on "github.com/gorilla/rpc/v2/json2" + +// Copyright 2009 The Go Authors. All rights reserved. +// Copyright 2012 The Gorilla Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +import ( + "bytes" + "encoding/json" + "errors" + "io" +) + +// ---------------------------------------------------------------------------- +// Request and Response +// ---------------------------------------------------------------------------- + +// clientRequest represents a JSON-RPC request sent by a client. +type clientRequest struct { + // JSON-RPC protocol. + Version string `json:"jsonrpc"` + + // A String containing the name of the method to be invoked. + Method string `json:"method"` + + // Object to pass as request parameter to the method. + Params interface{} `json:"params"` + + // The request id. This can be of any type. It is used to match the + // response with the request that it is replying to. + Id uint64 `json:"id"` +} + +// clientResponse represents a JSON-RPC response returned to a client. +type clientResponse struct { + Version string `json:"jsonrpc"` + Result *json.RawMessage `json:"result"` + Error *json.RawMessage `json:"error"` + Id *uint64 `json:"id"` +} + +// EncodeClientRequest encodes parameters for a JSON-RPC client request. 
+func EncodeClientRequest(method string, args interface{}) (*bytes.Buffer, error) { + var buf bytes.Buffer + c := &clientRequest{ + Version: "2.0", + Method: method, + Params: args, + Id: reqid(), + } + if err := json.NewEncoder(&buf).Encode(c); err != nil { + return nil, err + } + return &buf, nil +} + +func (c clientResponse) decode(reply interface{}) error { + if c.Error != nil { + jsonErr := &Error{} + if err := json.Unmarshal(*c.Error, jsonErr); err != nil { + return &Error{ + Code: E_SERVER, + Message: string(*c.Error), + } + } + return jsonErr + } + + if c.Result == nil { + return ErrNullResult + } + + return json.Unmarshal(*c.Result, reply) +} + +// DecodeClientResponse decodes the response body of a client request into +// the interface reply. +func DecodeClientResponse(r io.Reader, reply interface{}) error { + var c clientResponse + if err := json.NewDecoder(r).Decode(&c); err != nil { + return err + } + return c.decode(reply) +} + +type ErrorCode int + +const ( + E_PARSE ErrorCode = -32700 + E_INVALID_REQ ErrorCode = -32600 + E_NO_METHOD ErrorCode = -32601 + E_BAD_PARAMS ErrorCode = -32602 + E_INTERNAL ErrorCode = -32603 + E_SERVER ErrorCode = -32000 +) + +var ErrNullResult = errors.New("result is null") + +type Error struct { + // A Number that indicates the error type that occurred. + Code ErrorCode `json:"code"` /* required */ + + // A String providing a short description of the error. + // The message SHOULD be limited to a concise single sentence. + Message string `json:"message"` /* required */ + + // A Primitive or Structured value that contains additional information about the error. + Data interface{} `json:"data"` /* optional */ +} + +func (e *Error) Error() string { + return e.Message +} diff --git a/pkg/aria2/rpc/notification.go b/pkg/aria2/rpc/notification.go new file mode 100644 index 0000000000000000000000000000000000000000..ea44cd6eaf37cd2db7a51a6c6a7e524d5ac543b6 --- /dev/null +++ b/pkg/aria2/rpc/notification.go @@ -0,0 +1,44 @@ +package rpc + +import ( + log "github.com/sirupsen/logrus" +) + +type Event struct { + Gid string `json:"gid"` // GID of the download +} + +// The RPC server might send notifications to the client. +// Notifications is unidirectional, therefore the client which receives the notification must not respond to it. +// The method signature of a notification is much like a normal method request but lacks the id key + +type websocketResponse struct { + clientResponse + Method string `json:"method"` + Params []Event `json:"params"` +} + +// Notifier handles rpc notification from aria2 server +type Notifier interface { + // OnDownloadStart will be sent when a download is started. + OnDownloadStart([]Event) + // OnDownloadPause will be sent when a download is paused. + OnDownloadPause([]Event) + // OnDownloadStop will be sent when a download is stopped by the user. + OnDownloadStop([]Event) + // OnDownloadComplete will be sent when a download is complete. For BitTorrent downloads, this notification is sent when the download is complete and seeding is over. + OnDownloadComplete([]Event) + // OnDownloadError will be sent when a download is stopped due to an error. + OnDownloadError([]Event) + // OnBtDownloadComplete will be sent when a torrent download is complete but seeding is still going on. 
+ OnBtDownloadComplete([]Event) +} + +type DummyNotifier struct{} + +func (DummyNotifier) OnDownloadStart(events []Event) { log.Printf("%s started.", events) } +func (DummyNotifier) OnDownloadPause(events []Event) { log.Printf("%s paused.", events) } +func (DummyNotifier) OnDownloadStop(events []Event) { log.Printf("%s stopped.", events) } +func (DummyNotifier) OnDownloadComplete(events []Event) { log.Printf("%s completed.", events) } +func (DummyNotifier) OnDownloadError(events []Event) { log.Printf("%s error.", events) } +func (DummyNotifier) OnBtDownloadComplete(events []Event) { log.Printf("bt %s completed.", events) } diff --git a/pkg/aria2/rpc/proc.go b/pkg/aria2/rpc/proc.go new file mode 100644 index 0000000000000000000000000000000000000000..0184e6dc6b002f89057a02dd8aa5385e7e814787 --- /dev/null +++ b/pkg/aria2/rpc/proc.go @@ -0,0 +1,42 @@ +package rpc + +import "sync" + +type ResponseProcFn func(resp clientResponse) error + +type ResponseProcessor struct { + cbs map[uint64]ResponseProcFn + mu *sync.RWMutex +} + +func NewResponseProcessor() *ResponseProcessor { + return &ResponseProcessor{ + make(map[uint64]ResponseProcFn), + &sync.RWMutex{}, + } +} + +func (r *ResponseProcessor) Add(id uint64, fn ResponseProcFn) { + r.mu.Lock() + r.cbs[id] = fn + r.mu.Unlock() +} + +func (r *ResponseProcessor) remove(id uint64) { + r.mu.Lock() + delete(r.cbs, id) + r.mu.Unlock() +} + +// Process called by recv routine +func (r *ResponseProcessor) Process(resp clientResponse) error { + id := *resp.Id + r.mu.RLock() + fn, ok := r.cbs[id] + r.mu.RUnlock() + if ok && fn != nil { + defer r.remove(id) + return fn(resp) + } + return nil +} diff --git a/pkg/aria2/rpc/proto.go b/pkg/aria2/rpc/proto.go new file mode 100644 index 0000000000000000000000000000000000000000..3f5bf6db0754c9866df71f4d4a71d8efbf2a5cd8 --- /dev/null +++ b/pkg/aria2/rpc/proto.go @@ -0,0 +1,40 @@ +package rpc + +// Protocol is a set of rpc methods that aria2 daemon supports +type Protocol interface { + AddURI(uris []string, options ...interface{}) (gid string, err error) + AddTorrent(filename string, options ...interface{}) (gid string, err error) + AddMetalink(filename string, options ...interface{}) (gid []string, err error) + Remove(gid string) (g string, err error) + ForceRemove(gid string) (g string, err error) + Pause(gid string) (g string, err error) + PauseAll() (ok string, err error) + ForcePause(gid string) (g string, err error) + ForcePauseAll() (ok string, err error) + Unpause(gid string) (g string, err error) + UnpauseAll() (ok string, err error) + TellStatus(gid string, keys ...string) (info StatusInfo, err error) + GetURIs(gid string) (infos []URIInfo, err error) + GetFiles(gid string) (infos []FileInfo, err error) + GetPeers(gid string) (infos []PeerInfo, err error) + GetServers(gid string) (infos []ServerInfo, err error) + TellActive(keys ...string) (infos []StatusInfo, err error) + TellWaiting(offset, num int, keys ...string) (infos []StatusInfo, err error) + TellStopped(offset, num int, keys ...string) (infos []StatusInfo, err error) + ChangePosition(gid string, pos int, how string) (p int, err error) + ChangeURI(gid string, fileindex int, delUris []string, addUris []string, position ...int) (p []int, err error) + GetOption(gid string) (m Option, err error) + ChangeOption(gid string, option Option) (ok string, err error) + GetGlobalOption() (m Option, err error) + ChangeGlobalOption(options Option) (ok string, err error) + GetGlobalStat() (info GlobalStatInfo, err error) + PurgeDownloadResult() (ok string, err error) 
+ RemoveDownloadResult(gid string) (ok string, err error) + GetVersion() (info VersionInfo, err error) + GetSessionInfo() (info SessionInfo, err error) + Shutdown() (ok string, err error) + ForceShutdown() (ok string, err error) + SaveSession() (ok string, err error) + Multicall(methods []Method) (r []interface{}, err error) + ListMethods() (methods []string, err error) +} diff --git a/pkg/aria2/rpc/resp.go b/pkg/aria2/rpc/resp.go new file mode 100644 index 0000000000000000000000000000000000000000..f1d5fdd26ecac675c7c95b532456a767e129df85 --- /dev/null +++ b/pkg/aria2/rpc/resp.go @@ -0,0 +1,102 @@ +//go:generate easyjson -all + +package rpc + +// StatusInfo represents response of aria2.tellStatus +type StatusInfo struct { + Gid string `json:"gid"` // GID of the download. + Status string `json:"status"` // active for currently downloading/seeding downloads. waiting for downloads in the queue; download is not started. paused for paused downloads. error for downloads that were stopped because of error. complete for stopped and completed downloads. removed for the downloads removed by user. + TotalLength string `json:"totalLength"` // Total length of the download in bytes. + CompletedLength string `json:"completedLength"` // Completed length of the download in bytes. + UploadLength string `json:"uploadLength"` // Uploaded length of the download in bytes. + BitField string `json:"bitfield"` // Hexadecimal representation of the download progress. The highest bit corresponds to the piece at index 0. Any set bits indicate loaded pieces, while unset bits indicate not yet loaded and/or missing pieces. Any overflow bits at the end are set to zero. When the download was not started yet, this key will not be included in the response. + DownloadSpeed string `json:"downloadSpeed"` // Download speed of this download measured in bytes/sec. + UploadSpeed string `json:"uploadSpeed"` // Upload speed of this download measured in bytes/sec. + InfoHash string `json:"infoHash"` // InfoHash. BitTorrent only. + NumSeeders string `json:"numSeeders"` // The number of seeders aria2 has connected to. BitTorrent only. + Seeder string `json:"seeder"` // true if the local endpoint is a seeder. Otherwise, false. BitTorrent only. + PieceLength string `json:"pieceLength"` // Piece length in bytes. + NumPieces string `json:"numPieces"` // The number of pieces. + Connections string `json:"connections"` // The number of peers/servers aria2 has connected to. + ErrorCode string `json:"errorCode"` // The code of the last error for this item, if any. The value is a string. The error codes are defined in the EXIT STATUS section. This value is only available for stopped/completed downloads. + ErrorMessage string `json:"errorMessage"` // The (hopefully) human-readable error message associated to errorCode. + FollowedBy []string `json:"followedBy"` // List of GIDs which are generated as the result of this download. For example, when aria2 downloads a Metalink file, it generates downloads described in the Metalink (see the --follow-metalink option). This value is useful to track auto-generated downloads. If there are no such downloads, this key will not be included in the response. + BelongsTo string `json:"belongsTo"` // GID of a parent download. Some downloads are a part of another download. For example, if a file in a Metalink has BitTorrent resources, the downloads of ".torrent" files are parts of that parent. If this download has no parent, this key will not be included in the response. 
+ Dir string `json:"dir"` // Directory to save files. + Files []FileInfo `json:"files"` // Returns the list of files. The elements of this list are the same structs used in aria2.getFiles() method. + BitTorrent struct { + AnnounceList [][]string `json:"announceList"` // List of lists of announce URIs. If the torrent contains announce and no announce-list, announce is converted to the announce-list format. + Comment string `json:"comment"` // The comment of the torrent. comment.utf-8 is used if available. + CreationDate int64 `json:"creationDate"` // The creation time of the torrent. The value is an integer since the epoch, measured in seconds. + Mode string `json:"mode"` // File mode of the torrent. The value is either single or multi. + Info struct { + Name string `json:"name"` // name in info dictionary. name.utf-8 is used if available. + } `json:"info"` // Struct which contains data from Info dictionary. It contains following keys. + } `json:"bittorrent"` // Struct which contains information retrieved from the .torrent (file). BitTorrent only. It contains following keys. +} + +// URIInfo represents an element of response of aria2.getUris +type URIInfo struct { + URI string `json:"uri"` // URI + Status string `json:"status"` // 'used' if the URI is in use. 'waiting' if the URI is still waiting in the queue. +} + +// FileInfo represents an element of response of aria2.getFiles +type FileInfo struct { + Index string `json:"index"` // Index of the file, starting at 1, in the same order as files appear in the multi-file torrent. + Path string `json:"path"` // File path. + Length string `json:"length"` // File size in bytes. + CompletedLength string `json:"completedLength"` // Completed length of this file in bytes. Please note that it is possible that sum of completedLength is less than the completedLength returned by the aria2.tellStatus() method. This is because completedLength in aria2.getFiles() only includes completed pieces. On the other hand, completedLength in aria2.tellStatus() also includes partially completed pieces. + Selected string `json:"selected"` // true if this file is selected by --select-file option. If --select-file is not specified or this is single-file torrent or not a torrent download at all, this value is always true. Otherwise false. + URIs []URIInfo `json:"uris"` // Returns a list of URIs for this file. The element type is the same struct used in the aria2.getUris() method. +} + +// PeerInfo represents an element of response of aria2.getPeers +type PeerInfo struct { + PeerId string `json:"peerId"` // Percent-encoded peer ID. + IP string `json:"ip"` // IP address of the peer. + Port string `json:"port"` // Port number of the peer. + BitField string `json:"bitfield"` // Hexadecimal representation of the download progress of the peer. The highest bit corresponds to the piece at index 0. Set bits indicate the piece is available and unset bits indicate the piece is missing. Any spare bits at the end are set to zero. + AmChoking string `json:"amChoking"` // true if aria2 is choking the peer. Otherwise false. + PeerChoking string `json:"peerChoking"` // true if the peer is choking aria2. Otherwise false. + DownloadSpeed string `json:"downloadSpeed"` // Download speed (byte/sec) that this client obtains from the peer. + UploadSpeed string `json:"uploadSpeed"` // Upload speed(byte/sec) that this client uploads to the peer. + Seeder string `json:"seeder"` // true if this peer is a seeder. Otherwise false. 
+} + +// ServerInfo represents an element of response of aria2.getServers +type ServerInfo struct { + Index string `json:"index"` // Index of the file, starting at 1, in the same order as files appear in the multi-file metalink. + Servers []struct { + URI string `json:"uri"` // Original URI. + CurrentURI string `json:"currentUri"` // This is the URI currently used for downloading. If redirection is involved, currentUri and uri may differ. + DownloadSpeed string `json:"downloadSpeed"` // Download speed (byte/sec) + } `json:"servers"` // A list of structs which contain the following keys. +} + +// GlobalStatInfo represents response of aria2.getGlobalStat +type GlobalStatInfo struct { + DownloadSpeed string `json:"downloadSpeed"` // Overall download speed (byte/sec). + UploadSpeed string `json:"uploadSpeed"` // Overall upload speed(byte/sec). + NumActive string `json:"numActive"` // The number of active downloads. + NumWaiting string `json:"numWaiting"` // The number of waiting downloads. + NumStopped string `json:"numStopped"` // The number of stopped downloads in the current session. This value is capped by the --max-download-result option. + NumStoppedTotal string `json:"numStoppedTotal"` // The number of stopped downloads in the current session and not capped by the --max-download-result option. +} + +// VersionInfo represents response of aria2.getVersion +type VersionInfo struct { + Version string `json:"version"` // Version number of aria2 as a string. + Features []string `json:"enabledFeatures"` // List of enabled features. Each feature is given as a string. +} + +// SessionInfo represents response of aria2.getSessionInfo +type SessionInfo struct { + Id string `json:"sessionId"` // Session ID, which is generated each time when aria2 is invoked. +} + +// Method is an element of parameters used in system.multicall +type Method struct { + Name string `json:"methodName"` // Method name to call + Params []interface{} `json:"params"` // Array containing parameters to the method call +} diff --git a/pkg/chanio/chanio.go b/pkg/chanio/chanio.go new file mode 100644 index 0000000000000000000000000000000000000000..074229fe557a90a9211f2feb321d8a7cf374a12d --- /dev/null +++ b/pkg/chanio/chanio.go @@ -0,0 +1,62 @@ +package chanio + +import ( + "io" + "sync/atomic" +) + +type ChanIO struct { + cl atomic.Bool + c chan []byte + buf []byte +} + +func New() *ChanIO { + return &ChanIO{ + cl: atomic.Bool{}, + c: make(chan []byte), + buf: make([]byte, 0), + } +} + +func (c *ChanIO) Read(p []byte) (int, error) { + if c.cl.Load() { + if len(c.buf) == 0 { + return 0, io.EOF + } + n := copy(p, c.buf) + if len(c.buf) > n { + c.buf = c.buf[n:] + } else { + c.buf = make([]byte, 0) + } + return n, nil + } + for len(c.buf) < len(p) && !c.cl.Load() { + c.buf = append(c.buf, <-c.c...) 
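+		// Note: each receive blocks until Write sends another chunk; after Close,
+		// the closed channel yields empty slices and the cl flag ends the loop.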
+ } + n := copy(p, c.buf) + if len(c.buf) > n { + c.buf = c.buf[n:] + } else { + c.buf = make([]byte, 0) + } + return n, nil +} + +func (c *ChanIO) Write(p []byte) (int, error) { + if c.cl.Load() { + return 0, io.ErrClosedPipe + } + c.c <- p + return len(p), nil +} + +func (c *ChanIO) Close() error { + if c.cl.Load() { + return io.ErrClosedPipe + } + c.cl.Store(true) + close(c.c) + return nil +} diff --git a/pkg/cookie/cookie.go b/pkg/cookie/cookie.go new file mode 100644 index 0000000000000000000000000000000000000000..8a6ca859351676eb5deff29a9a1d1a32806c3540 --- /dev/null +++ b/pkg/cookie/cookie.go @@ -0,0 +1,59 @@ +package cookie + +import ( + "net/http" + "strings" +) + +func Parse(str string) []*http.Cookie { + header := http.Header{} + header.Add("Cookie", str) + request := http.Request{Header: header} + return request.Cookies() +} + +func ToString(cookies []*http.Cookie) string { + if cookies == nil { + return "" + } + cookieStrings := make([]string, len(cookies)) + for i, cookie := range cookies { + cookieStrings[i] = cookie.String() + } + return strings.Join(cookieStrings, ";") +} + +func SetCookie(cookies []*http.Cookie, name, value string) []*http.Cookie { + for i, cookie := range cookies { + if cookie.Name == name { + cookies[i].Value = value + return cookies + } + } + cookies = append(cookies, &http.Cookie{Name: name, Value: value}) + return cookies +} + +func GetCookie(cookies []*http.Cookie, name string) *http.Cookie { + for _, cookie := range cookies { + if cookie.Name == name { + return cookie + } + } + return nil +} + +func SetStr(cookiesStr, name, value string) string { + cookies := Parse(cookiesStr) + cookies = SetCookie(cookies, name, value) + return ToString(cookies) +} + +func GetStr(cookiesStr, name string) string { + cookies := Parse(cookiesStr) + cookie := GetCookie(cookies, name) + if cookie == nil { + return "" + } + return cookie.Value +} diff --git a/pkg/cron/cron.go b/pkg/cron/cron.go new file mode 100644 index 0000000000000000000000000000000000000000..3a3e978cfa0cbcdbd0ee188902b6f815367fd3b1 --- /dev/null +++ b/pkg/cron/cron.go @@ -0,0 +1,39 @@ +package cron + +import "time" + +type Cron struct { + d time.Duration + ch chan struct{} +} + +func NewCron(d time.Duration) *Cron { + return &Cron{ + d: d, + ch: make(chan struct{}), + } +} + +func (c *Cron) Do(f func()) { + go func() { + ticker := time.NewTicker(c.d) + defer ticker.Stop() + for { + select { + case <-ticker.C: + f() + case <-c.ch: + return + } + } + }() +} + +func (c *Cron) Stop() { + select { + case _, _ = <-c.ch: + default: + c.ch <- struct{}{} + close(c.ch) + } +} diff --git a/pkg/cron/cron_test.go b/pkg/cron/cron_test.go new file mode 100644 index 0000000000000000000000000000000000000000..1bd7cf2dfa32984af7de137e14ed3fd26773a3d2 --- /dev/null +++ b/pkg/cron/cron_test.go @@ -0,0 +1,16 @@ +package cron + +import ( + "testing" + "time" +) + +func TestCron(t *testing.T) { + c := NewCron(time.Second) + c.Do(func() { + t.Logf("cron log") + }) + time.Sleep(time.Second * 3) + c.Stop() + c.Stop() +} diff --git a/pkg/errgroup/errgroup.go b/pkg/errgroup/errgroup.go new file mode 100644 index 0000000000000000000000000000000000000000..858df044c268775b077876abf58e78dd2f9e9fb8 --- /dev/null +++ b/pkg/errgroup/errgroup.go @@ -0,0 +1,93 @@ +package errgroup + +import ( + "context" + "fmt" + "sync" + "sync/atomic" + + "github.com/avast/retry-go" +) + +type token struct{} +type Group struct { + cancel func(error) + ctx context.Context + opts []retry.Option + + success uint64 + + wg sync.WaitGroup + sem chan token 
+} + +func NewGroupWithContext(ctx context.Context, limit int, retryOpts ...retry.Option) (*Group, context.Context) { + ctx, cancel := context.WithCancelCause(ctx) + return (&Group{cancel: cancel, ctx: ctx, opts: append(retryOpts, retry.Context(ctx))}).SetLimit(limit), ctx +} + +func (g *Group) done() { + if g.sem != nil { + <-g.sem + } + g.wg.Done() + atomic.AddUint64(&g.success, 1) +} + +func (g *Group) Wait() error { + g.wg.Wait() + return context.Cause(g.ctx) +} + +func (g *Group) Go(f func(ctx context.Context) error) { + if g.sem != nil { + g.sem <- token{} + } + + g.wg.Add(1) + go func() { + defer g.done() + if err := retry.Do(func() error { return f(g.ctx) }, g.opts...); err != nil { + g.cancel(err) + } + }() +} + +func (g *Group) TryGo(f func(ctx context.Context) error) bool { + if g.sem != nil { + select { + case g.sem <- token{}: + default: + return false + } + } + + g.wg.Add(1) + go func() { + defer g.done() + if err := retry.Do(func() error { return f(g.ctx) }, g.opts...); err != nil { + g.cancel(err) + } + }() + return true +} + +func (g *Group) SetLimit(n int) *Group { + if len(g.sem) != 0 { + panic(fmt.Errorf("errgroup: modify limit while %v goroutines in the group are still active", len(g.sem))) + } + if n > 0 { + g.sem = make(chan token, n) + } else { + g.sem = nil + } + return g +} + +func (g *Group) Success() uint64 { + return atomic.LoadUint64(&g.success) +} + +func (g *Group) Err() error { + return context.Cause(g.ctx) +} diff --git a/pkg/generic/queue.go b/pkg/generic/queue.go new file mode 100644 index 0000000000000000000000000000000000000000..0ccc4bd9f8ca7f8fd6f9eaf2d100f02b8c795c0f --- /dev/null +++ b/pkg/generic/queue.go @@ -0,0 +1,75 @@ +package generic + +type Queue[T any] struct { + queue []T +} + +func NewQueue[T any]() *Queue[T] { + return &Queue[T]{queue: make([]T, 0)} +} + +func (q *Queue[T]) Push(v T) { + q.queue = append(q.queue, v) +} + +func (q *Queue[T]) Pop() T { + v := q.queue[0] + q.queue = q.queue[1:] + return v +} + +func (q *Queue[T]) Len() int { + return len(q.queue) +} + +func (q *Queue[T]) IsEmpty() bool { + return len(q.queue) == 0 +} + +func (q *Queue[T]) Clear() { + q.queue = nil +} + +func (q *Queue[T]) Peek() T { + return q.queue[0] +} + +func (q *Queue[T]) PeekN(n int) []T { + return q.queue[:n] +} + +func (q *Queue[T]) PopN(n int) []T { + v := q.queue[:n] + q.queue = q.queue[n:] + return v +} + +func (q *Queue[T]) PopAll() []T { + v := q.queue + q.queue = nil + return v +} + +func (q *Queue[T]) PopWhile(f func(T) bool) []T { + var i int + for i = 0; i < len(q.queue); i++ { + if !f(q.queue[i]) { + break + } + } + v := q.queue[:i] + q.queue = q.queue[i:] + return v +} + +func (q *Queue[T]) PopUntil(f func(T) bool) []T { + var i int + for i = 0; i < len(q.queue); i++ { + if f(q.queue[i]) { + break + } + } + v := q.queue[:i] + q.queue = q.queue[i:] + return v +} diff --git a/pkg/generic_sync/map.go b/pkg/generic_sync/map.go new file mode 100644 index 0000000000000000000000000000000000000000..96612f0cc445b37d4dd05f23dc8e0222c51fd769 --- /dev/null +++ b/pkg/generic_sync/map.go @@ -0,0 +1,412 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package generic_sync + +import ( + "sync" + "sync/atomic" + "unsafe" +) + +// MapOf is like a Go map[interface{}]interface{} but is safe for concurrent use +// by multiple goroutines without additional locking or coordination. +// Loads, stores, and deletes run in amortized constant time. 
+// +// The MapOf type is specialized. Most code should use a plain Go map instead, +// with separate locking or coordination, for better type safety and to make it +// easier to maintain other invariants along with the map content. +// +// The MapOf type is optimized for two common use cases: (1) when the entry for a given +// key is only ever written once but read many times, as in caches that only grow, +// or (2) when multiple goroutines read, write, and overwrite entries for disjoint +// sets of keys. In these two cases, use of a MapOf may significantly reduce lock +// contention compared to a Go map paired with a separate Mutex or RWMutex. +// +// The zero MapOf is empty and ready for use. A MapOf must not be copied after first use. +type MapOf[K comparable, V any] struct { + mu sync.Mutex + + // read contains the portion of the map's contents that are safe for + // concurrent access (with or without mu held). + // + // The read field itself is always safe to load, but must only be stored with + // mu held. + // + // Entries stored in read may be updated concurrently without mu, but updating + // a previously-expunged entry requires that the entry be copied to the dirty + // map and unexpunged with mu held. + read atomic.Value // readOnly + + // dirty contains the portion of the map's contents that require mu to be + // held. To ensure that the dirty map can be promoted to the read map quickly, + // it also includes all of the non-expunged entries in the read map. + // + // Expunged entries are not stored in the dirty map. An expunged entry in the + // clean map must be unexpunged and added to the dirty map before a new value + // can be stored to it. + // + // If the dirty map is nil, the next write to the map will initialize it by + // making a shallow copy of the clean map, omitting stale entries. + dirty map[K]*entry[V] + + // misses counts the number of loads since the read map was last updated that + // needed to lock mu to determine whether the key was present. + // + // Once enough misses have occurred to cover the cost of copying the dirty + // map, the dirty map will be promoted to the read map (in the unamended + // state) and the next store to the map will make a new dirty copy. + misses int +} + +// readOnly is an immutable struct stored atomically in the MapOf.read field. +type readOnly[K comparable, V any] struct { + m map[K]*entry[V] + amended bool // true if the dirty map contains some key not in m. +} + +// expunged is an arbitrary pointer that marks entries which have been deleted +// from the dirty map. +var expunged = unsafe.Pointer(new(interface{})) + +// An entry is a slot in the map corresponding to a particular key. +type entry[V any] struct { + // p points to the interface{} value stored for the entry. + // + // If p == nil, the entry has been deleted and m.dirty == nil. + // + // If p == expunged, the entry has been deleted, m.dirty != nil, and the entry + // is missing from m.dirty. + // + // Otherwise, the entry is valid and recorded in m.read.m[key] and, if m.dirty + // != nil, in m.dirty[key]. + // + // An entry can be deleted by atomic replacement with nil: when m.dirty is + // next created, it will atomically replace nil with expunged and leave + // m.dirty[key] unset. + // + // An entry's associated value can be updated by atomic replacement, provided + // p != expunged. If p == expunged, an entry's associated value can be updated + // only after first setting m.dirty[key] = e so that lookups using the dirty + // map find the entry. 
+ p unsafe.Pointer // *interface{} +} + +func newEntry[V any](i V) *entry[V] { + return &entry[V]{p: unsafe.Pointer(&i)} +} + +// Load returns the value stored in the map for a key, or nil if no +// value is present. +// The ok result indicates whether value was found in the map. +func (m *MapOf[K, V]) Load(key K) (value V, ok bool) { + read, _ := m.read.Load().(readOnly[K, V]) + e, ok := read.m[key] + if !ok && read.amended { + m.mu.Lock() + // Avoid reporting a spurious miss if m.dirty got promoted while we were + // blocked on m.mu. (If further loads of the same key will not miss, it's + // not worth copying the dirty map for this key.) + read, _ = m.read.Load().(readOnly[K, V]) + e, ok = read.m[key] + if !ok && read.amended { + e, ok = m.dirty[key] + // Regardless of whether the entry was present, record a miss: this key + // will take the slow path until the dirty map is promoted to the read + // map. + m.missLocked() + } + m.mu.Unlock() + } + if !ok { + return value, false + } + return e.load() +} + +func (m *MapOf[K, V]) Has(key K) bool { + _, ok := m.Load(key) + return ok +} + +func (e *entry[V]) load() (value V, ok bool) { + p := atomic.LoadPointer(&e.p) + if p == nil || p == expunged { + return value, false + } + return *(*V)(p), true +} + +// Store sets the value for a key. +func (m *MapOf[K, V]) Store(key K, value V) { + read, _ := m.read.Load().(readOnly[K, V]) + if e, ok := read.m[key]; ok && e.tryStore(&value) { + return + } + + m.mu.Lock() + read, _ = m.read.Load().(readOnly[K, V]) + if e, ok := read.m[key]; ok { + if e.unexpungeLocked() { + // The entry was previously expunged, which implies that there is a + // non-nil dirty map and this entry is not in it. + m.dirty[key] = e + } + e.storeLocked(&value) + } else if e, ok := m.dirty[key]; ok { + e.storeLocked(&value) + } else { + if !read.amended { + // We're adding the first new key to the dirty map. + // Make sure it is allocated and mark the read-only map as incomplete. + m.dirtyLocked() + m.read.Store(readOnly[K, V]{m: read.m, amended: true}) + } + m.dirty[key] = newEntry(value) + } + m.mu.Unlock() +} + +// tryStore stores a value if the entry has not been expunged. +// +// If the entry is expunged, tryStore returns false and leaves the entry +// unchanged. +func (e *entry[V]) tryStore(i *V) bool { + for { + p := atomic.LoadPointer(&e.p) + if p == expunged { + return false + } + if atomic.CompareAndSwapPointer(&e.p, p, unsafe.Pointer(i)) { + return true + } + } +} + +// unexpungeLocked ensures that the entry is not marked as expunged. +// +// If the entry was previously expunged, it must be added to the dirty map +// before m.mu is unlocked. +func (e *entry[V]) unexpungeLocked() (wasExpunged bool) { + return atomic.CompareAndSwapPointer(&e.p, expunged, nil) +} + +// storeLocked unconditionally stores a value to the entry. +// +// The entry must be known not to be expunged. +func (e *entry[V]) storeLocked(i *V) { + atomic.StorePointer(&e.p, unsafe.Pointer(i)) +} + +// LoadOrStore returns the existing value for the key if present. +// Otherwise, it stores and returns the given value. +// The loaded result is true if the value was loaded, false if stored. +func (m *MapOf[K, V]) LoadOrStore(key K, value V) (actual V, loaded bool) { + // Avoid locking if it's a clean hit. 
+	read, _ := m.read.Load().(readOnly[K, V])
+	if e, ok := read.m[key]; ok {
+		actual, loaded, ok := e.tryLoadOrStore(value)
+		if ok {
+			return actual, loaded
+		}
+	}
+
+	m.mu.Lock()
+	read, _ = m.read.Load().(readOnly[K, V])
+	if e, ok := read.m[key]; ok {
+		if e.unexpungeLocked() {
+			m.dirty[key] = e
+		}
+		actual, loaded, _ = e.tryLoadOrStore(value)
+	} else if e, ok := m.dirty[key]; ok {
+		actual, loaded, _ = e.tryLoadOrStore(value)
+		m.missLocked()
+	} else {
+		if !read.amended {
+			// We're adding the first new key to the dirty map.
+			// Make sure it is allocated and mark the read-only map as incomplete.
+			m.dirtyLocked()
+			m.read.Store(readOnly[K, V]{m: read.m, amended: true})
+		}
+		m.dirty[key] = newEntry(value)
+		actual, loaded = value, false
+	}
+	m.mu.Unlock()
+
+	return actual, loaded
+}
+
+// tryLoadOrStore atomically loads or stores a value if the entry is not
+// expunged.
+//
+// If the entry is expunged, tryLoadOrStore leaves the entry unchanged and
+// returns with ok==false.
+func (e *entry[V]) tryLoadOrStore(i V) (actual V, loaded, ok bool) {
+	p := atomic.LoadPointer(&e.p)
+	if p == expunged {
+		return actual, false, false
+	}
+	if p != nil {
+		return *(*V)(p), true, true
+	}
+
+	// Copy the value after the first load to make this method more amenable
+	// to escape analysis: if we hit the "load" path or the entry is expunged, we
+	// shouldn't bother heap-allocating.
+	ic := i
+	for {
+		if atomic.CompareAndSwapPointer(&e.p, nil, unsafe.Pointer(&ic)) {
+			return i, false, true
+		}
+		p = atomic.LoadPointer(&e.p)
+		if p == expunged {
+			return actual, false, false
+		}
+		if p != nil {
+			return *(*V)(p), true, true
+		}
+	}
+}
+
+// Delete deletes the value for a key.
+func (m *MapOf[K, V]) Delete(key K) {
+	read, _ := m.read.Load().(readOnly[K, V])
+	e, ok := read.m[key]
+	if !ok && read.amended {
+		m.mu.Lock()
+		read, _ = m.read.Load().(readOnly[K, V])
+		e, ok = read.m[key]
+		if !ok && read.amended {
+			delete(m.dirty, key)
+		}
+		m.mu.Unlock()
+	}
+	if ok {
+		e.delete()
+	}
+}
+
+func (e *entry[V]) delete() (hadValue bool) {
+	for {
+		p := atomic.LoadPointer(&e.p)
+		if p == nil || p == expunged {
+			return false
+		}
+		if atomic.CompareAndSwapPointer(&e.p, p, nil) {
+			return true
+		}
+	}
+}
+
+// Range calls f sequentially for each key and value present in the map.
+// If f returns false, range stops the iteration.
+//
+// Range does not necessarily correspond to any consistent snapshot of the MapOf's
+// contents: no key will be visited more than once, but if the value for any key
+// is stored or deleted concurrently, Range may reflect any mapping for that key
+// from any point during the Range call.
+//
+// Range may be O(N) with the number of elements in the map even if f returns
+// false after a constant number of calls.
+func (m *MapOf[K, V]) Range(f func(key K, value V) bool) {
+	// We need to be able to iterate over all of the keys that were already
+	// present at the start of the call to Range.
+	// If read.amended is false, then read.m satisfies that property without
+	// requiring us to hold m.mu for a long time.
+	read, _ := m.read.Load().(readOnly[K, V])
+	if read.amended {
+		// m.dirty contains keys not in read.m. Fortunately, Range is already O(N)
+		// (assuming the caller does not break out early), so a call to Range
+		// amortizes an entire copy of the map: we can promote the dirty copy
+		// immediately!
+ m.mu.Lock() + read, _ = m.read.Load().(readOnly[K, V]) + if read.amended { + read = readOnly[K, V]{m: m.dirty} + m.read.Store(read) + m.dirty = nil + m.misses = 0 + } + m.mu.Unlock() + } + + for k, e := range read.m { + v, ok := e.load() + if !ok { + continue + } + if !f(k, v) { + break + } + } +} + +// Values returns a slice of the values in the map. +func (m *MapOf[K, V]) Values() []V { + var values []V + m.Range(func(key K, value V) bool { + values = append(values, value) + return true + }) + return values +} + +func (m *MapOf[K, V]) Count() int { + return len(m.dirty) +} + +func (m *MapOf[K, V]) Empty() bool { + return m.Count() == 0 +} + +func (m *MapOf[K, V]) ToMap() map[K]V { + ans := make(map[K]V) + m.Range(func(key K, value V) bool { + ans[key] = value + return true + }) + return ans +} + +func (m *MapOf[K, V]) Clear() { + m.Range(func(key K, value V) bool { + m.Delete(key) + return true + }) +} + +func (m *MapOf[K, V]) missLocked() { + m.misses++ + if m.misses < len(m.dirty) { + return + } + m.read.Store(readOnly[K, V]{m: m.dirty}) + m.dirty = nil + m.misses = 0 +} + +func (m *MapOf[K, V]) dirtyLocked() { + if m.dirty != nil { + return + } + + read, _ := m.read.Load().(readOnly[K, V]) + m.dirty = make(map[K]*entry[V], len(read.m)) + for k, e := range read.m { + if !e.tryExpungeLocked() { + m.dirty[k] = e + } + } +} + +func (e *entry[V]) tryExpungeLocked() (isExpunged bool) { + p := atomic.LoadPointer(&e.p) + for p == nil { + if atomic.CompareAndSwapPointer(&e.p, nil, expunged) { + return true + } + p = atomic.LoadPointer(&e.p) + } + return p == expunged +} diff --git a/pkg/generic_sync/map_test.go b/pkg/generic_sync/map_test.go new file mode 100644 index 0000000000000000000000000000000000000000..22d7831979633b5c456f9110ca0a04f1c70406c7 --- /dev/null +++ b/pkg/generic_sync/map_test.go @@ -0,0 +1,74 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +package generic_sync_test + +import ( + "math/rand" + "runtime" + "sync" + "testing" + + "github.com/alist-org/alist/v3/pkg/generic_sync" +) + +func TestConcurrentRange(t *testing.T) { + const mapSize = 1 << 10 + + m := new(generic_sync.MapOf[int64, int64]) + for n := int64(1); n <= mapSize; n++ { + m.Store(n, int64(n)) + } + + done := make(chan struct{}) + var wg sync.WaitGroup + defer func() { + close(done) + wg.Wait() + }() + for g := int64(runtime.GOMAXPROCS(0)); g > 0; g-- { + r := rand.New(rand.NewSource(g)) + wg.Add(1) + go func(g int64) { + defer wg.Done() + for i := int64(0); ; i++ { + select { + case <-done: + return + default: + } + for n := int64(1); n < mapSize; n++ { + if r.Int63n(mapSize) == 0 { + m.Store(n, n*i*g) + } else { + m.Load(n) + } + } + } + }(g) + } + + iters := 1 << 10 + if testing.Short() { + iters = 16 + } + for n := iters; n > 0; n-- { + seen := make(map[int64]bool, mapSize) + + m.Range(func(k, v int64) bool { + if v%k != 0 { + t.Fatalf("while Storing multiples of %v, Range saw value %v", k, v) + } + if seen[k] { + t.Fatalf("Range visited key %v twice", k) + } + seen[k] = true + return true + }) + + if len(seen) != mapSize { + t.Fatalf("Range visited %v elements of %v-element MapOf", len(seen), mapSize) + } + } +} diff --git a/pkg/gowebdav/.gitignore b/pkg/gowebdav/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..394b2f5f86658aea0878236246fb411e323bdd34 --- /dev/null +++ b/pkg/gowebdav/.gitignore @@ -0,0 +1,21 @@ +# Folders to ignore +/src +/bin +/pkg +/gowebdav +/.idea + +# Binaries for programs and plugins +*.exe +*.exe~ +*.dll +*.so +*.dylib + +# Test binary, build with `go test -c` +*.test + +# Output of the go coverage tool, specifically when used with LiteIDE +*.out + +.vscode/ \ No newline at end of file diff --git a/pkg/gowebdav/.travis.yml b/pkg/gowebdav/.travis.yml new file mode 100644 index 0000000000000000000000000000000000000000..76bfb654c626b88745eb23c3061a75ec1e19ab98 --- /dev/null +++ b/pkg/gowebdav/.travis.yml @@ -0,0 +1,10 @@ +language: go + +go: + - "1.x" + +install: + - go get ./... + +script: + - go test -v --short ./... \ No newline at end of file diff --git a/pkg/gowebdav/LICENSE b/pkg/gowebdav/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..a7cd4420f08a09a280d4a497336d4b9d14123b0c --- /dev/null +++ b/pkg/gowebdav/LICENSE @@ -0,0 +1,27 @@ +Copyright (c) 2014, Studio B12 GmbH +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its contributors + may be used to endorse or promote products derived from this software without + specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/pkg/gowebdav/Makefile b/pkg/gowebdav/Makefile
new file mode 100644
index 0000000000000000000000000000000000000000..c6a0062c15d8a53ed02283d29190f39637dc215f
--- /dev/null
+++ b/pkg/gowebdav/Makefile
@@ -0,0 +1,33 @@
+BIN := gowebdav
+SRC := $(wildcard *.go) cmd/gowebdav/main.go
+
+all: test cmd
+
+cmd: ${BIN}
+
+${BIN}: ${SRC}
+	go build -o $@ ./cmd/gowebdav
+
+test:
+	go test -v --short ./...
+
+api:
+	@sed '/^## API$$/,$$d' -i README.md
+	@echo '## API' >> README.md
+	@godoc2md github.com/studio-b12/gowebdav | sed '/^$$/N;/^\n$$/D' |\
+	sed '2d' |\
+	sed 's/\/src\/github.com\/studio-b12\/gowebdav\//https:\/\/github.com\/studio-b12\/gowebdav\/blob\/master\//g' |\
+	sed 's/\/src\/target\//https:\/\/github.com\/studio-b12\/gowebdav\/blob\/master\//g' |\
+	sed 's/^#/##/g' >> README.md
+
+check:
+	gofmt -w -s $(SRC)
+	@echo
+	gocyclo -over 15 .
+	@echo
+	golint ./...
+
+clean:
+	@rm -f ${BIN}
+
+.PHONY: all cmd clean test api check
diff --git a/pkg/gowebdav/README.md b/pkg/gowebdav/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..31d9fe7bd3cdc6aa859e6e54934956e76ce12446
--- /dev/null
+++ b/pkg/gowebdav/README.md
@@ -0,0 +1,564 @@
+# GoWebDAV
+
+[![Build Status](https://travis-ci.org/studio-b12/gowebdav.svg?branch=master)](https://travis-ci.org/studio-b12/gowebdav)
+[![GoDoc](https://godoc.org/github.com/studio-b12/gowebdav?status.svg)](https://godoc.org/github.com/studio-b12/gowebdav)
+[![Go Report Card](https://goreportcard.com/badge/github.com/studio-b12/gowebdav)](https://goreportcard.com/report/github.com/studio-b12/gowebdav)
+
+A Go WebDAV client library.
+
+## Main features
+The `gowebdav` library allows you to perform the following actions on a remote WebDAV server:
+* [create path](#create-path-on-a-webdav-server)
+* [get files list](#get-files-list)
+* [download file](#download-file-to-byte-array)
+* [upload file](#upload-file-from-byte-array)
+* [get information about specified file/folder](#get-information-about-specified-filefolder)
+* [move file to another location](#move-file-to-another-location)
+* [copy file to another location](#copy-file-to-another-location)
+* [delete file](#delete-file)
+
+## Usage
+
+First, create a `Client` instance using the `NewClient()` function:
+
+```go
+root := "https://webdav.mydomain.me"
+user := "user"
+password := "password"
+
+c := gowebdav.NewClient(root, user, password)
+```
+
+You can then use this `Client` to perform the actions described below.
+
+**NOTICE:** the examples below do not check errors, to keep the focus on the `gowebdav` API, but you should handle them in your own code!
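+
+For real applications, check every returned error. A minimal error-checked session might look like the sketch below (the URL and folder name are placeholders):
+
+```go
+c := gowebdav.NewClient("https://webdav.mydomain.me", "user", "password")
+if err := c.Connect(); err != nil {
+	log.Fatalf("connect failed: %v", err)
+}
+
+files, err := c.ReadDir("folder")
+if err != nil {
+	log.Fatalf("read dir failed: %v", err)
+}
+for _, f := range files {
+	fmt.Println(f.Name())
+}
+```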
+
+### Create path on a WebDAV server
+```go
+err := c.Mkdir("folder", 0644)
+```
+If you want to create several nested folders, you can use `c.MkdirAll()`:
+```go
+err := c.MkdirAll("folder/subfolder/subfolder2", 0644)
+```
+
+### Get files list
+```go
+files, _ := c.ReadDir("folder/subfolder")
+for _, file := range files {
+	// notice that [file] has the os.FileInfo type
+	fmt.Println(file.Name())
+}
+```
+
+### Download file to byte array
+```go
+webdavFilePath := "folder/subfolder/file.txt"
+localFilePath := "/tmp/webdav/file.txt"
+
+bytes, _ := c.Read(webdavFilePath)
+ioutil.WriteFile(localFilePath, bytes, 0644)
+```
+
+### Download file via reader
+Alternatively, you can use the `c.ReadStream()` method:
+```go
+webdavFilePath := "folder/subfolder/file.txt"
+localFilePath := "/tmp/webdav/file.txt"
+
+reader, _ := c.ReadStream(webdavFilePath)
+
+file, _ := os.Create(localFilePath)
+defer file.Close()
+
+io.Copy(file, reader)
+```
+
+### Upload file from byte array
+```go
+webdavFilePath := "folder/subfolder/file.txt"
+localFilePath := "/tmp/webdav/file.txt"
+
+bytes, _ := ioutil.ReadFile(localFilePath)
+
+c.Write(webdavFilePath, bytes, 0644)
+```
+
+### Upload file via writer
+```go
+webdavFilePath := "folder/subfolder/file.txt"
+localFilePath := "/tmp/webdav/file.txt"
+
+file, _ := os.Open(localFilePath)
+defer file.Close()
+
+c.WriteStream(webdavFilePath, file, 0644)
+```
+
+### Get information about specified file/folder
+```go
+webdavFilePath := "folder/subfolder/file.txt"
+
+info, _ := c.Stat(webdavFilePath)
+// notice that [info] has the os.FileInfo type
+fmt.Println(info)
+```
+
+### Move file to another location
+```go
+oldPath := "folder/subfolder/file.txt"
+newPath := "folder/subfolder/moved.txt"
+isOverwrite := true
+
+c.Rename(oldPath, newPath, isOverwrite)
+```
+
+### Copy file to another location
+```go
+oldPath := "folder/subfolder/file.txt"
+newPath := "folder/subfolder/file-copy.txt"
+isOverwrite := true
+
+c.Copy(oldPath, newPath, isOverwrite)
+```
+
+### Delete file
+```go
+webdavFilePath := "folder/subfolder/file.txt"
+
+c.Remove(webdavFilePath)
+```
+
+## Links
+
+You can read more about the WebDAV protocol in the following resources:
+
+* [RFC 4918 - HTTP Extensions for Web Distributed Authoring and Versioning (WebDAV)](https://tools.ietf.org/html/rfc4918)
+* [RFC 5689 - Extended MKCOL for Web Distributed Authoring and Versioning (WebDAV)](https://tools.ietf.org/html/rfc5689)
+* [RFC 2616 - HTTP/1.1 Status Code Definitions](http://www.w3.org/Protocols/rfc2616/rfc2616-sec10.html "HTTP/1.1 Status Code Definitions")
+* [WebDav: Next Generation Collaborative Web Authoring By Lisa Dusseault](https://books.google.de/books?isbn=0130652083 "WebDav: Next Generation Collaborative Web Authoring By Lisa Dusseault")
+
+**NOTICE**: RFC 2518 was obsoleted by RFC 4918 in June 2007.
+
+## Contributing
+All contributions are welcome. If you have a suggestion or find a bug, please open an issue so we can make this project better. We appreciate your help!
+
+## License
+This library is distributed under the BSD 3-Clause license found in the [LICENSE](https://github.com/studio-b12/gowebdav/blob/master/LICENSE) file.
+
+## API
+
+`import "github.com/studio-b12/gowebdav"`
+
+* [Overview](#pkg-overview)
+* [Index](#pkg-index)
+* [Examples](#pkg-examples)
+* [Subdirectories](#pkg-subdirectories)
+
+### Overview
+Package gowebdav is a WebDAV client library with a command line tool
+included.
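+
+As a quick orientation before the reference, the package-level path helpers behave roughly as follows. This is a sketch; the printed values show the expected output, not a normative specification:
+
+```go
+fmt.Println(gowebdav.FixSlash("dir"))                 // "dir/"
+fmt.Println(gowebdav.FixSlashes("dir/sub"))           // "/dir/sub/"
+fmt.Println(gowebdav.Join("https://host/dav/", "/a")) // "https://host/dav/a"
+fmt.Println(gowebdav.PathEscape("/dir/a file.txt"))   // "/dir/a%20file.txt"
+```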
+ +### Index +* [func FixSlash(s string) string](#FixSlash) +* [func FixSlashes(s string) string](#FixSlashes) +* [func Join(path0 string, path1 string) string](#Join) +* [func PathEscape(path string) string](#PathEscape) +* [func ReadConfig(uri, netrc string) (string, string)](#ReadConfig) +* [func String(r io.Reader) string](#String) +* [type Authenticator](#Authenticator) +* [type BasicAuth](#BasicAuth) + * [func (b *BasicAuth) Authorize(req *http.Request, method string, path string)](#BasicAuth.Authorize) + * [func (b *BasicAuth) Pass() string](#BasicAuth.Pass) + * [func (b *BasicAuth) Type() string](#BasicAuth.Type) + * [func (b *BasicAuth) User() string](#BasicAuth.User) +* [type Client](#Client) + * [func NewClient(uri, user, pw string) *Client](#NewClient) + * [func (c *Client) Connect() error](#Client.Connect) + * [func (c *Client) Copy(oldpath, newpath string, overwrite bool) error](#Client.Copy) + * [func (c *Client) Mkdir(path string, _ os.FileMode) error](#Client.Mkdir) + * [func (c *Client) MkdirAll(path string, _ os.FileMode) error](#Client.MkdirAll) + * [func (c *Client) Read(path string) ([]byte, error)](#Client.Read) + * [func (c *Client) ReadDir(path string) ([]os.FileInfo, error)](#Client.ReadDir) + * [func (c *Client) ReadStream(path string) (io.ReadCloser, error)](#Client.ReadStream) + * [func (c *Client) ReadStreamRange(path string, offset, length int64) (io.ReadCloser, error)](#Client.ReadStreamRange) + * [func (c *Client) Remove(path string) error](#Client.Remove) + * [func (c *Client) RemoveAll(path string) error](#Client.RemoveAll) + * [func (c *Client) Rename(oldpath, newpath string, overwrite bool) error](#Client.Rename) + * [func (c *Client) SetHeader(key, value string)](#Client.SetHeader) + * [func (c *Client) SetInterceptor(interceptor func(method string, rq *http.Request))](#Client.SetInterceptor) + * [func (c *Client) SetTimeout(timeout time.Duration)](#Client.SetTimeout) + * [func (c *Client) SetTransport(transport http.RoundTripper)](#Client.SetTransport) + * [func (c *Client) Stat(path string) (os.FileInfo, error)](#Client.Stat) + * [func (c *Client) Write(path string, data []byte, _ os.FileMode) error](#Client.Write) + * [func (c *Client) WriteStream(path string, stream io.Reader, _ os.FileMode) error](#Client.WriteStream) +* [type DigestAuth](#DigestAuth) + * [func (d *DigestAuth) Authorize(req *http.Request, method string, path string)](#DigestAuth.Authorize) + * [func (d *DigestAuth) Pass() string](#DigestAuth.Pass) + * [func (d *DigestAuth) Type() string](#DigestAuth.Type) + * [func (d *DigestAuth) User() string](#DigestAuth.User) +* [type File](#File) + * [func (f File) ContentType() string](#File.ContentType) + * [func (f File) ETag() string](#File.ETag) + * [func (f File) IsDir() bool](#File.IsDir) + * [func (f File) ModTime() time.Time](#File.ModTime) + * [func (f File) Mode() os.FileMode](#File.Mode) + * [func (f File) Name() string](#File.Name) + * [func (f File) Path() string](#File.Path) + * [func (f File) Size() int64](#File.Size) + * [func (f File) String() string](#File.String) + * [func (f File) Sys() interface{}](#File.Sys) +* [type NoAuth](#NoAuth) + * [func (n *NoAuth) Authorize(req *http.Request, method string, path string)](#NoAuth.Authorize) + * [func (n *NoAuth) Pass() string](#NoAuth.Pass) + * [func (n *NoAuth) Type() string](#NoAuth.Type) + * [func (n *NoAuth) User() string](#NoAuth.User) + +##### Examples +* [PathEscape](#example_PathEscape) + +##### Package files 
+[basicAuth.go](https://github.com/studio-b12/gowebdav/blob/master/basicAuth.go) [client.go](https://github.com/studio-b12/gowebdav/blob/master/client.go) [digestAuth.go](https://github.com/studio-b12/gowebdav/blob/master/digestAuth.go) [doc.go](https://github.com/studio-b12/gowebdav/blob/master/doc.go) [file.go](https://github.com/studio-b12/gowebdav/blob/master/file.go) [netrc.go](https://github.com/studio-b12/gowebdav/blob/master/netrc.go) [requests.go](https://github.com/studio-b12/gowebdav/blob/master/requests.go) [utils.go](https://github.com/studio-b12/gowebdav/blob/master/utils.go) + +### func [FixSlash](https://github.com/studio-b12/gowebdav/blob/master/utils.go?s=707:737#L45) +``` go +func FixSlash(s string) string +``` +FixSlash appends a trailing / to our string + +### func [FixSlashes](https://github.com/studio-b12/gowebdav/blob/master/utils.go?s=859:891#L53) +``` go +func FixSlashes(s string) string +``` +FixSlashes appends and prepends a / if they are missing + +### func [Join](https://github.com/studio-b12/gowebdav/blob/master/utils.go?s=992:1036#L62) +``` go +func Join(path0 string, path1 string) string +``` +Join joins two paths + +### func [PathEscape](https://github.com/studio-b12/gowebdav/blob/master/utils.go?s=506:541#L36) +``` go +func PathEscape(path string) string +``` +PathEscape escapes all segments of a given path + +### func [ReadConfig](https://github.com/studio-b12/gowebdav/blob/master/netrc.go?s=428:479#L27) +``` go +func ReadConfig(uri, netrc string) (string, string) +``` +ReadConfig reads login and password configuration from ~/.netrc +machine foo.com login username password 123456 + +### func [String](https://github.com/studio-b12/gowebdav/blob/master/utils.go?s=1166:1197#L67) +``` go +func String(r io.Reader) string +``` +String pulls a string out of our io.Reader + +### type [Authenticator](https://github.com/studio-b12/gowebdav/blob/master/client.go?s=388:507#L29) +``` go +type Authenticator interface { + Type() string + User() string + Pass() string + Authorize(*http.Request, string, string) +} +``` +Authenticator stub + +### type [BasicAuth](https://github.com/studio-b12/gowebdav/blob/master/basicAuth.go?s=106:157#L9) +``` go +type BasicAuth struct { + // contains filtered or unexported fields +} +``` +BasicAuth structure holds our credentials + +#### func (\*BasicAuth) [Authorize](https://github.com/studio-b12/gowebdav/blob/master/basicAuth.go?s=473:549#L30) +``` go +func (b *BasicAuth) Authorize(req *http.Request, method string, path string) +``` +Authorize the current request + +#### func (\*BasicAuth) [Pass](https://github.com/studio-b12/gowebdav/blob/master/basicAuth.go?s=388:421#L25) +``` go +func (b *BasicAuth) Pass() string +``` +Pass holds the BasicAuth password + +#### func (\*BasicAuth) [Type](https://github.com/studio-b12/gowebdav/blob/master/basicAuth.go?s=201:234#L15) +``` go +func (b *BasicAuth) Type() string +``` +Type identifies the BasicAuthenticator + +#### func (\*BasicAuth) [User](https://github.com/studio-b12/gowebdav/blob/master/basicAuth.go?s=297:330#L20) +``` go +func (b *BasicAuth) User() string +``` +User holds the BasicAuth username + +### type [Client](https://github.com/studio-b12/gowebdav/blob/master/client.go?s=172:364#L18) +``` go +type Client struct { + // contains filtered or unexported fields +} +``` +Client defines our structure + +#### func [NewClient](https://github.com/studio-b12/gowebdav/blob/master/client.go?s=1019:1063#L62) +``` go +func NewClient(uri, user, pw string) *Client +``` +NewClient creates a new 
instance of client + +#### func (\*Client) [Connect](https://github.com/studio-b12/gowebdav/blob/master/client.go?s=1843:1875#L87) +``` go +func (c *Client) Connect() error +``` +Connect connects to our dav server + +#### func (\*Client) [Copy](https://github.com/studio-b12/gowebdav/blob/master/client.go?s=6702:6770#L313) +``` go +func (c *Client) Copy(oldpath, newpath string, overwrite bool) error +``` +Copy copies a file from A to B + +#### func (\*Client) [Mkdir](https://github.com/studio-b12/gowebdav/blob/master/client.go?s=5793:5849#L272) +``` go +func (c *Client) Mkdir(path string, _ os.FileMode) error +``` +Mkdir makes a directory + +#### func (\*Client) [MkdirAll](https://github.com/studio-b12/gowebdav/blob/master/client.go?s=6028:6087#L283) +``` go +func (c *Client) MkdirAll(path string, _ os.FileMode) error +``` +MkdirAll like mkdir -p, but for webdav + +#### func (\*Client) [Read](https://github.com/studio-b12/gowebdav/blob/master/client.go?s=6876:6926#L318) +``` go +func (c *Client) Read(path string) ([]byte, error) +``` +Read reads the contents of a remote file + +#### func (\*Client) [ReadDir](https://github.com/studio-b12/gowebdav/blob/master/client.go?s=2869:2929#L130) +``` go +func (c *Client) ReadDir(path string) ([]os.FileInfo, error) +``` +ReadDir reads the contents of a remote directory + +#### func (\*Client) [ReadStream](https://github.com/studio-b12/gowebdav/blob/master/client.go?s=7237:7300#L336) +``` go +func (c *Client) ReadStream(path string) (io.ReadCloser, error) +``` +ReadStream reads the stream for a given path + +#### func (\*Client) [ReadStreamRange](https://github.com/studio-b12/gowebdav/blob/master/client.go?s=8049:8139#L358) +``` go +func (c *Client) ReadStreamRange(path string, offset, length int64) (io.ReadCloser, error) +``` +ReadStreamRange reads the stream representing a subset of bytes for a given path, +utilizing HTTP Range Requests if the server supports it. +The range is expressed as offset from the start of the file and length, for example +offset=10, length=10 will return bytes 10 through 19. + +If the server does not support partial content requests and returns full content instead, +this function will emulate the behavior by skipping `offset` bytes and limiting the result +to `length`. 
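+
+A usage sketch for `ReadStreamRange` (the path is a placeholder, and error handling is elided as in the examples above): to fetch bytes 10 through 19 of a remote file, pass offset 10 and length 10:
+
+```go
+rc, _ := c.ReadStreamRange("folder/big.bin", 10, 10)
+defer rc.Close()
+
+chunk, _ := io.ReadAll(rc)
+fmt.Printf("got %d bytes\n", len(chunk))
+```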
+ +#### func (\*Client) [Remove](https://github.com/studio-b12/gowebdav/blob/master/client.go?s=5299:5341#L249) +``` go +func (c *Client) Remove(path string) error +``` +Remove removes a remote file + +#### func (\*Client) [RemoveAll](https://github.com/studio-b12/gowebdav/blob/master/client.go?s=5407:5452#L254) +``` go +func (c *Client) RemoveAll(path string) error +``` +RemoveAll removes remote files + +#### func (\*Client) [Rename](https://github.com/studio-b12/gowebdav/blob/master/client.go?s=6536:6606#L308) +``` go +func (c *Client) Rename(oldpath, newpath string, overwrite bool) error +``` +Rename moves a file from A to B + +#### func (\*Client) [SetHeader](https://github.com/studio-b12/gowebdav/blob/master/client.go?s=1235:1280#L67) +``` go +func (c *Client) SetHeader(key, value string) +``` +SetHeader lets us set arbitrary headers for a given client + +#### func (\*Client) [SetInterceptor](https://github.com/studio-b12/gowebdav/blob/master/client.go?s=1387:1469#L72) +``` go +func (c *Client) SetInterceptor(interceptor func(method string, rq *http.Request)) +``` +SetInterceptor lets us set an arbitrary interceptor for a given client + +#### func (\*Client) [SetTimeout](https://github.com/studio-b12/gowebdav/blob/master/client.go?s=1571:1621#L77) +``` go +func (c *Client) SetTimeout(timeout time.Duration) +``` +SetTimeout exposes the ability to set a time limit for requests + +#### func (\*Client) [SetTransport](https://github.com/studio-b12/gowebdav/blob/master/client.go?s=1714:1772#L82) +``` go +func (c *Client) SetTransport(transport http.RoundTripper) +``` +SetTransport exposes the ability to define custom transports + +#### func (\*Client) [Stat](https://github.com/studio-b12/gowebdav/blob/master/client.go?s=4255:4310#L197) +``` go +func (c *Client) Stat(path string) (os.FileInfo, error) +``` +Stat returns the file stats for a specified path + +#### func (\*Client) [Write](https://github.com/studio-b12/gowebdav/blob/master/client.go?s=9051:9120#L388) +``` go +func (c *Client) Write(path string, data []byte, _ os.FileMode) error +``` +Write writes data to a given path + +#### func (\*Client) [WriteStream](https://github.com/studio-b12/gowebdav/blob/master/client.go?s=9476:9556#L411) +``` go +func (c *Client) WriteStream(path string, stream io.Reader, _ os.FileMode) error +``` +WriteStream writes a stream + +### type [DigestAuth](https://github.com/studio-b12/gowebdav/blob/master/digestAuth.go?s=157:254#L14) +``` go +type DigestAuth struct { + // contains filtered or unexported fields +} +``` +DigestAuth structure holds our credentials + +#### func (\*DigestAuth) [Authorize](https://github.com/studio-b12/gowebdav/blob/master/digestAuth.go?s=577:654#L36) +``` go +func (d *DigestAuth) Authorize(req *http.Request, method string, path string) +``` +Authorize the current request + +#### func (\*DigestAuth) [Pass](https://github.com/studio-b12/gowebdav/blob/master/digestAuth.go?s=491:525#L31) +``` go +func (d *DigestAuth) Pass() string +``` +Pass holds the DigestAuth password + +#### func (\*DigestAuth) [Type](https://github.com/studio-b12/gowebdav/blob/master/digestAuth.go?s=299:333#L21) +``` go +func (d *DigestAuth) Type() string +``` +Type identifies the DigestAuthenticator + +#### func (\*DigestAuth) [User](https://github.com/studio-b12/gowebdav/blob/master/digestAuth.go?s=398:432#L26) +``` go +func (d *DigestAuth) User() string +``` +User holds the DigestAuth username + +### type [File](https://github.com/studio-b12/gowebdav/blob/master/file.go?s=93:253#L10) +``` go +type File 
struct { + // contains filtered or unexported fields +} +``` +File is our structure for a given file + +#### func (File) [ContentType](https://github.com/studio-b12/gowebdav/blob/master/file.go?s=476:510#L31) +``` go +func (f File) ContentType() string +``` +ContentType returns the content type of a file + +#### func (File) [ETag](https://github.com/studio-b12/gowebdav/blob/master/file.go?s=929:956#L56) +``` go +func (f File) ETag() string +``` +ETag returns the ETag of a file + +#### func (File) [IsDir](https://github.com/studio-b12/gowebdav/blob/master/file.go?s=1035:1061#L61) +``` go +func (f File) IsDir() bool +``` +IsDir let us see if a given file is a directory or not + +#### func (File) [ModTime](https://github.com/studio-b12/gowebdav/blob/master/file.go?s=836:869#L51) +``` go +func (f File) ModTime() time.Time +``` +ModTime returns the modified time of a file + +#### func (File) [Mode](https://github.com/studio-b12/gowebdav/blob/master/file.go?s=665:697#L41) +``` go +func (f File) Mode() os.FileMode +``` +Mode will return the mode of a given file + +#### func (File) [Name](https://github.com/studio-b12/gowebdav/blob/master/file.go?s=378:405#L26) +``` go +func (f File) Name() string +``` +Name returns the name of a file + +#### func (File) [Path](https://github.com/studio-b12/gowebdav/blob/master/file.go?s=295:322#L21) +``` go +func (f File) Path() string +``` +Path returns the full path of a file + +#### func (File) [Size](https://github.com/studio-b12/gowebdav/blob/master/file.go?s=573:599#L36) +``` go +func (f File) Size() int64 +``` +Size returns the size of a file + +#### func (File) [String](https://github.com/studio-b12/gowebdav/blob/master/file.go?s=1183:1212#L71) +``` go +func (f File) String() string +``` +String lets us see file information + +#### func (File) [Sys](https://github.com/studio-b12/gowebdav/blob/master/file.go?s=1095:1126#L66) +``` go +func (f File) Sys() interface{} +``` +Sys ???? 
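+
+The `os.FileInfo` values returned by `ReadDir` and `Stat` are backed by this `File` type, so WebDAV-specific fields can be recovered with a type assertion. A sketch (the folder name is a placeholder):
+
+```go
+files, _ := c.ReadDir("folder")
+for _, fi := range files {
+	if f, ok := fi.(gowebdav.File); ok {
+		fmt.Println(f.Path(), f.ETag(), f.ContentType())
+	}
+}
+```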
+ +### type [NoAuth](https://github.com/studio-b12/gowebdav/blob/master/client.go?s=551:599#L37) +``` go +type NoAuth struct { + // contains filtered or unexported fields +} +``` +NoAuth structure holds our credentials + +#### func (\*NoAuth) [Authorize](https://github.com/studio-b12/gowebdav/blob/master/client.go?s=894:967#L58) +``` go +func (n *NoAuth) Authorize(req *http.Request, method string, path string) +``` +Authorize the current request + +#### func (\*NoAuth) [Pass](https://github.com/studio-b12/gowebdav/blob/master/client.go?s=812:842#L53) +``` go +func (n *NoAuth) Pass() string +``` +Pass returns the current password + +#### func (\*NoAuth) [Type](https://github.com/studio-b12/gowebdav/blob/master/client.go?s=638:668#L43) +``` go +func (n *NoAuth) Type() string +``` +Type identifies the authenticator + +#### func (\*NoAuth) [User](https://github.com/studio-b12/gowebdav/blob/master/client.go?s=724:754#L48) +``` go +func (n *NoAuth) User() string +``` +User returns the current user + +- - - +Generated by [godoc2md](http://godoc.org/github.com/davecheney/godoc2md) diff --git a/pkg/gowebdav/basicAuth.go b/pkg/gowebdav/basicAuth.go new file mode 100644 index 0000000000000000000000000000000000000000..bdb86da580c06f23eed66e7a5745e10ba9536c0d --- /dev/null +++ b/pkg/gowebdav/basicAuth.go @@ -0,0 +1,34 @@ +package gowebdav + +import ( + "encoding/base64" + "net/http" +) + +// BasicAuth structure holds our credentials +type BasicAuth struct { + user string + pw string +} + +// Type identifies the BasicAuthenticator +func (b *BasicAuth) Type() string { + return "BasicAuth" +} + +// User holds the BasicAuth username +func (b *BasicAuth) User() string { + return b.user +} + +// Pass holds the BasicAuth password +func (b *BasicAuth) Pass() string { + return b.pw +} + +// Authorize the current request +func (b *BasicAuth) Authorize(req *http.Request, method string, path string) { + a := b.user + ":" + b.pw + auth := "Basic " + base64.StdEncoding.EncodeToString([]byte(a)) + req.Header.Set("Authorization", auth) +} diff --git a/pkg/gowebdav/client.go b/pkg/gowebdav/client.go new file mode 100644 index 0000000000000000000000000000000000000000..2fca0b7f43db9424ab3e868b6f8022ada9d89e36 --- /dev/null +++ b/pkg/gowebdav/client.go @@ -0,0 +1,484 @@ +package gowebdav + +import ( + "bytes" + "encoding/xml" + "fmt" + "io" + "net/http" + "net/url" + "os" + pathpkg "path" + "strings" + "sync" + "time" +) + +// Client defines our structure +type Client struct { + root string + headers http.Header + interceptor func(method string, rq *http.Request) + c *http.Client + + authMutex sync.Mutex + auth Authenticator +} + +// Authenticator stub +type Authenticator interface { + Type() string + User() string + Pass() string + Authorize(*http.Request, string, string) +} + +// NoAuth structure holds our credentials +type NoAuth struct { + user string + pw string +} + +// Type identifies the authenticator +func (n *NoAuth) Type() string { + return "NoAuth" +} + +// User returns the current user +func (n *NoAuth) User() string { + return n.user +} + +// Pass returns the current password +func (n *NoAuth) Pass() string { + return n.pw +} + +// Authorize the current request +func (n *NoAuth) Authorize(req *http.Request, method string, path string) { +} + +// NewClient creates a new instance of client +func NewClient(uri, user, pw string) *Client { + return &Client{FixSlash(uri), make(http.Header), nil, &http.Client{}, sync.Mutex{}, &NoAuth{user, pw}} +} + +// SetHeader lets us set arbitrary headers for a given client 
+func (c *Client) SetHeader(key, value string) {
+	c.headers.Add(key, value)
+}
+
+// SetInterceptor lets us set an arbitrary interceptor for a given client
+func (c *Client) SetInterceptor(interceptor func(method string, rq *http.Request)) {
+	c.interceptor = interceptor
+}
+
+// SetTimeout exposes the ability to set a time limit for requests
+func (c *Client) SetTimeout(timeout time.Duration) {
+	c.c.Timeout = timeout
+}
+
+// SetTransport exposes the ability to define custom transports
+func (c *Client) SetTransport(transport http.RoundTripper) {
+	c.c.Transport = transport
+}
+
+// SetJar exposes the ability to set a cookie jar to the client.
+func (c *Client) SetJar(jar http.CookieJar) {
+	c.c.Jar = jar
+}
+
+// Connect connects to our dav server
+func (c *Client) Connect() error {
+	rs, err := c.options("/")
+	if err != nil {
+		return err
+	}
+
+	err = rs.Body.Close()
+	if err != nil {
+		return err
+	}
+
+	if rs.StatusCode != 200 {
+		return newPathError("Connect", c.root, rs.StatusCode)
+	}
+
+	return nil
+}
+
+type props struct {
+	Status      string   `xml:"DAV: status"`
+	Name        string   `xml:"DAV: prop>displayname,omitempty"`
+	Type        xml.Name `xml:"DAV: prop>resourcetype>collection,omitempty"`
+	Size        string   `xml:"DAV: prop>getcontentlength,omitempty"`
+	ContentType string   `xml:"DAV: prop>getcontenttype,omitempty"`
+	ETag        string   `xml:"DAV: prop>getetag,omitempty"`
+	Modified    string   `xml:"DAV: prop>getlastmodified,omitempty"`
+}
+
+type response struct {
+	Href  string  `xml:"DAV: href"`
+	Props []props `xml:"DAV: propstat"`
+}
+
+func getProps(r *response, status string) *props {
+	for _, prop := range r.Props {
+		if strings.Contains(prop.Status, status) {
+			return &prop
+		}
+	}
+	return nil
+}
+
+// ReadDir reads the contents of a remote directory
+func (c *Client) ReadDir(path string) ([]os.FileInfo, error) {
+	path = FixSlashes(path)
+	files := make([]os.FileInfo, 0)
+	skipSelf := true
+	parse := func(resp interface{}) error {
+		r := resp.(*response)
+
+		if skipSelf {
+			skipSelf = false
+			if p := getProps(r, "200"); p != nil && p.Type.Local == "collection" {
+				r.Props = nil
+				return nil
+			}
+			return newPathError("ReadDir", path, 405)
+		}
+
+		if p := getProps(r, "200"); p != nil {
+			f := new(File)
+			if ps, err := url.PathUnescape(r.Href); err == nil {
+				f.name = pathpkg.Base(ps)
+			} else {
+				f.name = p.Name
+			}
+			f.path = path + f.name
+			f.modified = parseModified(&p.Modified)
+			f.etag = p.ETag
+			f.contentType = p.ContentType
+
+			if p.Type.Local == "collection" {
+				f.path += "/"
+				f.size = 0
+				f.isdir = true
+			} else {
+				f.size = parseInt64(&p.Size)
+				f.isdir = false
+			}
+
+			files = append(files, *f)
+		}
+
+		r.Props = nil
+		return nil
+	}
+
+	err := c.propfind(path, false,
+		`<d:propfind xmlns:d='DAV:'>
+			<d:prop>
+				<d:displayname/>
+				<d:resourcetype/>
+				<d:getcontentlength/>
+				<d:getcontenttype/>
+				<d:getetag/>
+				<d:getlastmodified/>
+			</d:prop>
+		</d:propfind>`,
+		&response{},
+		parse)
+
+	if err != nil {
+		if _, ok := err.(*os.PathError); !ok {
+			err = newPathErrorErr("ReadDir", path, err)
+		}
+	}
+	return files, err
+}
+
+// Stat returns the file stats for a specified path
+func (c *Client) Stat(path string) (os.FileInfo, error) {
+	var f *File
+	parse := func(resp interface{}) error {
+		r := resp.(*response)
+		if p := getProps(r, "200"); p != nil && f == nil {
+			f = new(File)
+			f.name = p.Name
+			f.path = path
+			f.etag = p.ETag
+			f.contentType = p.ContentType
+
+			if p.Type.Local == "collection" {
+				if !strings.HasSuffix(f.path, "/") {
+					f.path += "/"
+				}
+				f.size = 0
+				f.modified = time.Unix(0, 0)
+				f.isdir = true
+			} else {
+				f.size = parseInt64(&p.Size)
+				f.modified = parseModified(&p.Modified)
+				f.isdir = false
+			}
+		}
+
+		r.Props = nil
+		return nil
+	}
+
+	err := c.propfind(path, true,
+		`<d:propfind xmlns:d='DAV:'>
+			<d:prop>
+				<d:displayname/>
+				<d:resourcetype/>
+				<d:getcontentlength/>
+				<d:getcontenttype/>
+				<d:getetag/>
+				<d:getlastmodified/>
+			</d:prop>
+		</d:propfind>`,
+		&response{},
+		parse)
+
+	if err != nil {
+		if _, ok := err.(*os.PathError); !ok {
+			err = newPathErrorErr("Stat", path, err)
+		}
+	}
+	return f, err
+}
+
+// Remove removes a remote file
+func (c *Client) Remove(path string) error {
+	return c.RemoveAll(path)
+}
+
+// RemoveAll removes remote files
+func (c *Client) RemoveAll(path string) error {
+	rs, err := c.req("DELETE", path, nil, nil)
+	if err != nil {
+		return newPathError("Remove", path, 400)
+	}
+	err = rs.Body.Close()
+	if err != nil {
+		return err
+	}
+
+	if rs.StatusCode == 200 || rs.StatusCode == 204 || rs.StatusCode == 404 {
+		return nil
+	}
+
+	return newPathError("Remove", path, rs.StatusCode)
+}
+
+// Mkdir makes a directory
+func (c *Client) Mkdir(path string, _ os.FileMode) (err error) {
+	path = FixSlashes(path)
+	status, err := c.mkcol(path)
+	if err != nil {
+		return
+	}
+	if status == 201 {
+		return nil
+	}
+
+	return newPathError("Mkdir", path, status)
+}
+
+// MkdirAll like mkdir -p, but for webdav
+func (c *Client) MkdirAll(path string, _ os.FileMode) (err error) {
+	path = FixSlashes(path)
+	status, err := c.mkcol(path)
+	if err != nil {
+		return
+	}
+	if status == 201 {
+		return nil
+	}
+	if status == 409 {
+		paths := strings.Split(path, "/")
+		sub := "/"
+		for _, e := range paths {
+			if e == "" {
+				continue
+			}
+			sub += e + "/"
+			status, err = c.mkcol(sub)
+			if err != nil {
+				return
+			}
+			if status != 201 {
+				return newPathError("MkdirAll", sub, status)
+			}
+		}
+		return nil
+	}
+
+	return newPathError("MkdirAll", path, status)
+}
+
+// Rename moves a file from A to B
+func (c *Client) Rename(oldpath, newpath string, overwrite bool) error {
+	return c.copymove("MOVE", oldpath, newpath, overwrite)
+}
+
+// Copy copies a file from A to B
+func (c *Client) Copy(oldpath, newpath string, overwrite bool) error {
+	return c.copymove("COPY", oldpath, newpath, overwrite)
+}
+
+// Read reads the contents of a remote file
+func (c *Client) Read(path string) ([]byte, error) {
+	var stream io.ReadCloser
+	var err error
+
+	if stream, _, err = c.ReadStream(path, nil); err != nil {
+		return nil, err
+	}
+	defer stream.Close()
+
+	buf := new(bytes.Buffer)
+	_, err = buf.ReadFrom(stream)
+	if err != nil {
+		return nil, err
+	}
+	return buf.Bytes(), nil
+}
+
+// Link builds the request URL and headers for a GET of the given path
+// without performing the request.
+func (c *Client) Link(path string) (string, http.Header, error) {
+	method := "GET"
+	u := PathEscape(Join(c.root, path))
+	r, err := http.NewRequest(method, u, nil)
+
+	if err != nil {
+		return "", nil, newPathErrorErr("Link", path, err)
+	}
+
+	if c.c.Jar != nil {
+		for _, cookie := range c.c.Jar.Cookies(r.URL) {
+			r.AddCookie(cookie)
+		}
+	}
+	for k, vals := range c.headers {
+		for _, v := range vals {
+			r.Header.Add(k, v)
+		}
+	}
+
+	c.authMutex.Lock()
+	auth := c.auth
+	c.authMutex.Unlock()
+
+	auth.Authorize(r, method, path)
+
+	if c.interceptor != nil {
+		c.interceptor(method, r)
+	}
+	return r.URL.String(), r.Header, nil
+}
+
+// ReadStream reads the stream for a given path
+func (c *Client) ReadStream(path string, callback func(rq *http.Request)) (io.ReadCloser, http.Header, error) {
+	rs, err := c.req("GET", path, nil, callback)
+	if err != nil {
+		return nil, nil, newPathErrorErr("ReadStream", path, err)
+	}
+
+	if rs.StatusCode < 400 {
+		return rs.Body, rs.Header, nil
+	}
+
+	rs.Body.Close()
+	return nil, nil, newPathError("ReadStream", path, rs.StatusCode)
+}
+
+// ReadStreamRange reads the stream representing a subset of bytes for a given path,
+// utilizing HTTP Range Requests if the server supports it.
+// The range is expressed as offset from the start of the file and length, for example
+// offset=10, length=10 will return bytes 10 through 19.
+//
+// If the server does not support partial content requests and returns full content instead,
+// this function will emulate the behavior by skipping `offset` bytes and limiting the result
+// to `length`.
+func (c *Client) ReadStreamRange(path string, offset, length int64) (io.ReadCloser, error) {
+	rs, err := c.req("GET", path, nil, func(r *http.Request) {
+		r.Header.Add("Range", fmt.Sprintf("bytes=%v-%v", offset, offset+length-1))
+	})
+	if err != nil {
+		return nil, newPathErrorErr("ReadStreamRange", path, err)
+	}
+
+	if rs.StatusCode == http.StatusPartialContent {
+		// server supported partial content, return as-is.
+		return rs.Body, nil
+	}
+
+	// server returned success, but did not support partial content, so we have the whole
+	// stream in rs.Body
+	if rs.StatusCode == 200 {
+		// discard first 'offset' bytes.
+		if _, err := io.Copy(io.Discard, io.LimitReader(rs.Body, offset)); err != nil {
+			return nil, newPathErrorErr("ReadStreamRange", path, err)
+		}
+
+		// return an io.ReadCloser that is limited to `length` bytes.
+		return &limitedReadCloser{rs.Body, int(length)}, nil
+	}
+
+	rs.Body.Close()
+	return nil, newPathError("ReadStreamRange", path, rs.StatusCode)
+}
+
+// Write writes data to a given path
+func (c *Client) Write(path string, data []byte, _ os.FileMode) (err error) {
+	s, err := c.put(path, bytes.NewReader(data), nil)
+	if err != nil {
+		return
+	}
+
+	switch s {
+
+	case 200, 201, 204:
+		return nil
+
+	case 409:
+		err = c.createParentCollection(path)
+		if err != nil {
+			return
+		}
+
+		s, err = c.put(path, bytes.NewReader(data), nil)
+		if err != nil {
+			return
+		}
+		if s == 200 || s == 201 || s == 204 {
+			return
+		}
+	}
+
+	return newPathError("Write", path, s)
+}
+
+// WriteStream writes a stream
func (c *Client) WriteStream(path string, stream io.Reader, _ os.FileMode, callback func(r *http.Request)) (err error) {
+
+	err = c.createParentCollection(path)
+	if err != nil {
+		return err
+	}
+
+	s, err := c.put(path, stream, callback)
+	if err != nil {
+		return err
+	}
+
+	switch s {
+	case 200, 201, 204:
+		return nil
+
+	default:
+		return newPathError("WriteStream", path, s)
+	}
+}
diff --git a/pkg/gowebdav/cmd/gowebdav/README.md b/pkg/gowebdav/cmd/gowebdav/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..30e1d4ca7a0f19baffbfff928e07f5613ed0e818
--- /dev/null
+++ b/pkg/gowebdav/cmd/gowebdav/README.md
@@ -0,0 +1,103 @@
+# Description
+Command line tool for the [gowebdav](https://github.com/studio-b12/gowebdav) library.
+
+# Prerequisites
+## Software
+* **OS**: any OS supported by `Golang`
+* **Golang**: version 1.x
+* **Git**: version 2.14.2 or higher (required to install via `go get`)
+
+# Install
+```sh
+go get -u github.com/studio-b12/gowebdav/cmd/gowebdav
+```
+
+# Usage
+It is recommended to set the following environment variables to improve your experience with this tool:
+* `ROOT` is the URL of the target WebDAV server (e.g. `https://webdav.mydomain.me/user_root_folder`)
+* `USER` is the login used to connect to the specified server (e.g. `user`)
+* `PASSWORD` is the password used to connect to the specified server (e.g.
+
+In the following examples we assume that:
+* environment variable `ROOT` is set to `https://webdav.mydomain.me/ufolder`
+* environment variable `USER` is set to `user`
+* environment variable `PASSWORD` is set to `p@s$w0rD`
+* folder `/ufolder/temp` exists on the server
+* file `/ufolder/temp/file.txt` exists on the server
+* file `/ufolder/temp/document.rtf` exists on the server
+* file `/tmp/webdav/to_upload.txt` exists on the local machine
+* folder `/tmp/webdav/` is used to download files from the server
+
+## Examples
+
+#### Get the content of a specified folder
+```sh
+gowebdav -X LS temp
+```
+
+#### Get info about a file/folder
+```sh
+gowebdav -X STAT temp
+gowebdav -X STAT temp/file.txt
+```
+
+#### Create a folder on the remote server
+```sh
+gowebdav -X MKDIR temp2
+gowebdav -X MKDIRALL all/folders/which-you-want/to_create
+```
+
+#### Download a file
+```sh
+gowebdav -X GET temp/document.rtf /tmp/webdav/document.rtf
+```
+
+You may omit the target local path; in that case the file will be downloaded to the current folder with the same name it has on the server.
+
+#### Upload a file
+```sh
+gowebdav -X PUT temp/uploaded.txt /tmp/webdav/to_upload.txt
+```
+
+#### Move a file on the remote server
+```sh
+gowebdav -X MV temp/file.txt temp/moved_file.txt
+```
+
+#### Copy a file to another location
+```sh
+gowebdav -X CP temp/file.txt temp/file-copy.txt
+```
+
+#### Delete a file from the remote server
+```sh
+gowebdav -X DEL temp/file.txt
+```
+
+# Wrapper script
+
+You can create a wrapper script for your server (via `$EDITOR ./dav && chmod a+x ./dav`) and add the following content to it:
+```sh
+#!/bin/sh
+
+ROOT="https://my.dav.server/" \
+USER="foo" \
+PASSWORD="$(pass dav/foo@my.dav.server)" \
+gowebdav "$@"
+```
+
+This allows you to use [pass](https://www.passwordstore.org/ "the standard unix password manager") or similar tools to retrieve the password.
+
+## Examples
+
+Using the `dav` wrapper:
+
+```sh
+$ ./dav -X LS /
+
+$ echo hi dav!
> hello && ./dav -X PUT /hello +$ ./dav -X STAT /hello +$ ./dav -X PUT /hello_dav hello +$ ./dav -X GET /hello_dav +$ ./dav -X GET /hello_dav hello.txt +``` \ No newline at end of file diff --git a/pkg/gowebdav/cmd/gowebdav/main.go b/pkg/gowebdav/cmd/gowebdav/main.go new file mode 100644 index 0000000000000000000000000000000000000000..0164496f53b2efd69c873a6384f9bd929bef58f8 --- /dev/null +++ b/pkg/gowebdav/cmd/gowebdav/main.go @@ -0,0 +1,263 @@ +package main + +import ( + "errors" + "flag" + "fmt" + "io" + "io/fs" + "os" + "os/user" + "path" + "path/filepath" + "runtime" + "strings" + + d "github.com/alist-org/alist/v3/pkg/gowebdav" +) + +func main() { + root := flag.String("root", os.Getenv("ROOT"), "WebDAV Endpoint [ENV.ROOT]") + user := flag.String("user", os.Getenv("USER"), "User [ENV.USER]") + password := flag.String("pw", os.Getenv("PASSWORD"), "Password [ENV.PASSWORD]") + netrc := flag.String("netrc-file", filepath.Join(getHome(), ".netrc"), "read login from netrc file") + method := flag.String("X", "", `Method: + LS + STAT + + MKDIR + MKDIRALL + + GET [] + PUT [] + + MV + CP + + DEL + `) + flag.Parse() + + if *root == "" { + fail("Set WebDAV ROOT") + } + + if argsLength := len(flag.Args()); argsLength == 0 || argsLength > 2 { + fail("Unsupported arguments") + } + + if *password == "" { + if u, p := d.ReadConfig(*root, *netrc); u != "" && p != "" { + user = &u + password = &p + } + } + + c := d.NewClient(*root, *user, *password) + + cmd := getCmd(*method) + + if e := cmd(c, flag.Arg(0), flag.Arg(1)); e != nil { + fail(e) + } +} + +func fail(err interface{}) { + if err != nil { + fmt.Println(err) + } + os.Exit(-1) +} + +func getHome() string { + u, e := user.Current() + if e != nil { + return os.Getenv("HOME") + } + + if u != nil { + return u.HomeDir + } + + switch runtime.GOOS { + case "windows": + return "" + default: + return "~/" + } +} + +func getCmd(method string) func(c *d.Client, p0, p1 string) error { + switch strings.ToUpper(method) { + case "LS", "LIST", "PROPFIND": + return cmdLs + + case "STAT": + return cmdStat + + case "GET", "PULL", "READ": + return cmdGet + + case "DELETE", "RM", "DEL": + return cmdRm + + case "MKCOL", "MKDIR": + return cmdMkdir + + case "MKCOLALL", "MKDIRALL", "MKDIRP": + return cmdMkdirAll + + case "RENAME", "MV", "MOVE": + return cmdMv + + case "COPY", "CP": + return cmdCp + + case "PUT", "PUSH", "WRITE": + return cmdPut + + default: + return func(c *d.Client, p0, p1 string) (err error) { + return errors.New("Unsupported method: " + method) + } + } +} + +func cmdLs(c *d.Client, p0, _ string) (err error) { + files, err := c.ReadDir(p0) + if err == nil { + fmt.Println(fmt.Sprintf("ReadDir: '%s' entries: %d ", p0, len(files))) + for _, f := range files { + fmt.Println(f) + } + } + return +} + +func cmdStat(c *d.Client, p0, _ string) (err error) { + file, err := c.Stat(p0) + if err == nil { + fmt.Println(file) + } + return +} + +func cmdGet(c *d.Client, p0, p1 string) (err error) { + bytes, err := c.Read(p0) + if err == nil { + if p1 == "" { + p1 = filepath.Join(".", p0) + } + err = writeFile(p1, bytes, 0644) + if err == nil { + fmt.Println(fmt.Sprintf("Written %d bytes to: %s", len(bytes), p1)) + } + } + return +} + +func cmdRm(c *d.Client, p0, _ string) (err error) { + if err = c.Remove(p0); err == nil { + fmt.Println("Remove: " + p0) + } + return +} + +func cmdMkdir(c *d.Client, p0, _ string) (err error) { + if err = c.Mkdir(p0, 0755); err == nil { + fmt.Println("Mkdir: " + p0) + } + return +} + +func cmdMkdirAll(c *d.Client, p0, _ string) (err 
error) { + if err = c.MkdirAll(p0, 0755); err == nil { + fmt.Println("MkdirAll: " + p0) + } + return +} + +func cmdMv(c *d.Client, p0, p1 string) (err error) { + if err = c.Rename(p0, p1, true); err == nil { + fmt.Println("Rename: " + p0 + " -> " + p1) + } + return +} + +func cmdCp(c *d.Client, p0, p1 string) (err error) { + if err = c.Copy(p0, p1, true); err == nil { + fmt.Println("Copy: " + p0 + " -> " + p1) + } + return +} + +func cmdPut(c *d.Client, p0, p1 string) (err error) { + if p1 == "" { + p1 = path.Join(".", p0) + } else { + var fi fs.FileInfo + fi, err = c.Stat(p0) + if err != nil && !d.IsErrNotFound(err) { + return + } + if !d.IsErrNotFound(err) && fi.IsDir() { + p0 = path.Join(p0, p1) + } + } + + stream, err := getStream(p1) + if err != nil { + return + } + defer stream.Close() + + if err = c.WriteStream(p0, stream, 0644, nil); err == nil { + fmt.Println("Put: " + p1 + " -> " + p0) + } + return +} + +func writeFile(path string, bytes []byte, mode os.FileMode) error { + parent := filepath.Dir(path) + if _, e := os.Stat(parent); os.IsNotExist(e) { + if e := os.MkdirAll(parent, os.ModePerm); e != nil { + return e + } + } + + f, err := os.Create(path) + if err != nil { + return err + } + defer f.Close() + + _, err = f.Write(bytes) + return err +} + +func getStream(pathOrString string) (io.ReadCloser, error) { + + fi, err := os.Stat(pathOrString) + if err != nil { + return nil, err + } + + if fi.IsDir() { + return nil, &os.PathError{ + Op: "Open", + Path: pathOrString, + Err: errors.New("Path: '" + pathOrString + "' is a directory"), + } + } + + f, err := os.Open(pathOrString) + if err == nil { + return f, nil + } + + return nil, &os.PathError{ + Op: "Open", + Path: pathOrString, + Err: err, + } +} diff --git a/pkg/gowebdav/digestAuth.go b/pkg/gowebdav/digestAuth.go new file mode 100644 index 0000000000000000000000000000000000000000..4a5eb62f2fbb551c47f3398d340ad83a2bde8e8e --- /dev/null +++ b/pkg/gowebdav/digestAuth.go @@ -0,0 +1,146 @@ +package gowebdav + +import ( + "crypto/md5" + "crypto/rand" + "encoding/hex" + "fmt" + "io" + "net/http" + "strings" +) + +// DigestAuth structure holds our credentials +type DigestAuth struct { + user string + pw string + digestParts map[string]string +} + +// Type identifies the DigestAuthenticator +func (d *DigestAuth) Type() string { + return "DigestAuth" +} + +// User holds the DigestAuth username +func (d *DigestAuth) User() string { + return d.user +} + +// Pass holds the DigestAuth password +func (d *DigestAuth) Pass() string { + return d.pw +} + +// Authorize the current request +func (d *DigestAuth) Authorize(req *http.Request, method string, path string) { + d.digestParts["uri"] = path + d.digestParts["method"] = method + d.digestParts["username"] = d.user + d.digestParts["password"] = d.pw + req.Header.Set("Authorization", getDigestAuthorization(d.digestParts)) +} + +func digestParts(resp *http.Response) map[string]string { + result := map[string]string{} + if len(resp.Header["Www-Authenticate"]) > 0 { + wantedHeaders := []string{"nonce", "realm", "qop", "opaque", "algorithm", "entityBody"} + responseHeaders := strings.Split(resp.Header["Www-Authenticate"][0], ",") + for _, r := range responseHeaders { + for _, w := range wantedHeaders { + if strings.Contains(r, w) { + result[w] = strings.Trim( + strings.SplitN(r, `=`, 2)[1], + `"`, + ) + } + } + } + } + return result +} + +func getMD5(text string) string { + hasher := md5.New() + hasher.Write([]byte(text)) + return hex.EncodeToString(hasher.Sum(nil)) +} + +func getCnonce() string { + 
b := make([]byte, 8)
+	io.ReadFull(rand.Reader, b)
+	return fmt.Sprintf("%x", b)[:16]
+}
+
+func getDigestAuthorization(digestParts map[string]string) string {
+	d := digestParts
+	// Compute ha1, ha2 and response according to the negotiated
+	// "algorithm" and "qop" values (MD5 and MD5-sess only).
+
+	var (
+		ha1        string
+		ha2        string
+		nonceCount = "00000001" // nc is an 8-digit hex count per RFC 7616
+		cnonce     = getCnonce()
+		response   string
+	)
+
+	// 'ha1' value depends on value of "algorithm" field
+	switch d["algorithm"] {
+	case "MD5", "":
+		ha1 = getMD5(d["username"] + ":" + d["realm"] + ":" + d["password"])
+	case "MD5-sess":
+		ha1 = getMD5(
+			fmt.Sprintf("%s:%v:%s",
+				getMD5(d["username"]+":"+d["realm"]+":"+d["password"]),
+				nonceCount,
+				cnonce,
+			),
+		)
+	}
+
+	// 'ha2' value depends on value of "qop" field
+	switch d["qop"] {
+	case "auth", "":
+		ha2 = getMD5(d["method"] + ":" + d["uri"])
+	case "auth-int":
+		if d["entityBody"] != "" {
+			ha2 = getMD5(d["method"] + ":" + d["uri"] + ":" + getMD5(d["entityBody"]))
+		}
+	}
+
+	// 'response' value depends on value of "qop" field
+	switch d["qop"] {
+	case "":
+		response = getMD5(
+			fmt.Sprintf("%s:%s:%s",
+				ha1,
+				d["nonce"],
+				ha2,
+			),
+		)
+	case "auth", "auth-int":
+		response = getMD5(
+			fmt.Sprintf("%s:%s:%v:%s:%s:%s",
+				ha1,
+				d["nonce"],
+				nonceCount,
+				cnonce,
+				d["qop"],
+				ha2,
+			),
+		)
+	}
+
+	authorization := fmt.Sprintf(`Digest username="%s", realm="%s", nonce="%s", uri="%s", nc=%v, cnonce="%s", response="%s"`,
+		d["username"], d["realm"], d["nonce"], d["uri"], nonceCount, cnonce, response)
+
+	if d["qop"] != "" {
+		authorization += fmt.Sprintf(`, qop=%s`, d["qop"])
+	}
+
+	if d["opaque"] != "" {
+		authorization += fmt.Sprintf(`, opaque="%s"`, d["opaque"])
+	}
+
+	return authorization
+}
diff --git a/pkg/gowebdav/doc.go b/pkg/gowebdav/doc.go
new file mode 100644
index 0000000000000000000000000000000000000000..e47d5eee25b8fd48619089dc32d110d9c237bc55
--- /dev/null
+++ b/pkg/gowebdav/doc.go
@@ -0,0 +1,3 @@
+// Package gowebdav is a WebDAV client library with a command line tool
+// included.
+package gowebdav
diff --git a/pkg/gowebdav/errors.go b/pkg/gowebdav/errors.go
new file mode 100644
index 0000000000000000000000000000000000000000..bbf1e929eb84f0009424abecc0fb6b3fc3ae9be3
--- /dev/null
+++ b/pkg/gowebdav/errors.go
@@ -0,0 +1,49 @@
+package gowebdav
+
+import (
+	"fmt"
+	"os"
+)
+
+// StatusError implements error and wraps
+// an erroneous status code.
+type StatusError struct {
+	Status int
+}
+
+func (se StatusError) Error() string {
+	return fmt.Sprintf("%d", se.Status)
+}
+
+// IsErrCode returns true if the given error
+// is an os.PathError wrapping a StatusError
+// with the given status code.
+func IsErrCode(err error, code int) bool {
+	if pe, ok := err.(*os.PathError); ok {
+		se, ok := pe.Err.(StatusError)
+		return ok && se.Status == code
+	}
+	return false
+}
+
+// IsErrNotFound is shorthand for IsErrCode
+// for status 404.
+func IsErrNotFound(err error) bool { + return IsErrCode(err, 404) +} + +func newPathError(op string, path string, statusCode int) error { + return &os.PathError{ + Op: op, + Path: path, + Err: StatusError{statusCode}, + } +} + +func newPathErrorErr(op string, path string, err error) error { + return &os.PathError{ + Op: op, + Path: path, + Err: err, + } +} diff --git a/pkg/gowebdav/file.go b/pkg/gowebdav/file.go new file mode 100644 index 0000000000000000000000000000000000000000..ae2303fc8e0fe1d1470d0e5a2ea38c2a0b151441 --- /dev/null +++ b/pkg/gowebdav/file.go @@ -0,0 +1,77 @@ +package gowebdav + +import ( + "fmt" + "os" + "time" +) + +// File is our structure for a given file +type File struct { + path string + name string + contentType string + size int64 + modified time.Time + etag string + isdir bool +} + +// Path returns the full path of a file +func (f File) Path() string { + return f.path +} + +// Name returns the name of a file +func (f File) Name() string { + return f.name +} + +// ContentType returns the content type of a file +func (f File) ContentType() string { + return f.contentType +} + +// Size returns the size of a file +func (f File) Size() int64 { + return f.size +} + +// Mode will return the mode of a given file +func (f File) Mode() os.FileMode { + // TODO check webdav perms + if f.isdir { + return 0775 | os.ModeDir + } + + return 0664 +} + +// ModTime returns the modified time of a file +func (f File) ModTime() time.Time { + return f.modified +} + +// ETag returns the ETag of a file +func (f File) ETag() string { + return f.etag +} + +// IsDir let us see if a given file is a directory or not +func (f File) IsDir() bool { + return f.isdir +} + +// Sys ???? +func (f File) Sys() interface{} { + return nil +} + +// String lets us see file information +func (f File) String() string { + if f.isdir { + return fmt.Sprintf("Dir : '%s' - '%s'", f.path, f.name) + } + + return fmt.Sprintf("File: '%s' SIZE: %d MODIFIED: %s ETAG: %s CTYPE: %s", f.path, f.size, f.modified.String(), f.etag, f.contentType) +} diff --git a/pkg/gowebdav/netrc.go b/pkg/gowebdav/netrc.go new file mode 100644 index 0000000000000000000000000000000000000000..df479b52cfd0f12edb416473cb0fb85dd96f004d --- /dev/null +++ b/pkg/gowebdav/netrc.go @@ -0,0 +1,54 @@ +package gowebdav + +import ( + "bufio" + "fmt" + "net/url" + "os" + "regexp" + "strings" +) + +func parseLine(s string) (login, pass string) { + fields := strings.Fields(s) + for i, f := range fields { + if f == "login" { + login = fields[i+1] + } + if f == "password" { + pass = fields[i+1] + } + } + return login, pass +} + +// ReadConfig reads login and password configuration from ~/.netrc +// machine foo.com login username password 123456 +func ReadConfig(uri, netrc string) (string, string) { + u, err := url.Parse(uri) + if err != nil { + return "", "" + } + + file, err := os.Open(netrc) + if err != nil { + return "", "" + } + defer file.Close() + + re := fmt.Sprintf(`^.*machine %s.*$`, u.Host) + scanner := bufio.NewScanner(file) + for scanner.Scan() { + s := scanner.Text() + + matched, err := regexp.MatchString(re, s) + if err != nil { + return "", "" + } + if matched { + return parseLine(s) + } + } + + return "", "" +} diff --git a/pkg/gowebdav/requests.go b/pkg/gowebdav/requests.go new file mode 100644 index 0000000000000000000000000000000000000000..d6237767d38160f8bab3dbcb6963473e17316817 --- /dev/null +++ b/pkg/gowebdav/requests.go @@ -0,0 +1,218 @@ +package gowebdav + +import ( + "bytes" + "fmt" + "io" + "net/http" + "path" + "strings" +) + +func 
(c *Client) req(method, path string, body io.Reader, intercept func(*http.Request)) (req *http.Response, err error) { + var r *http.Request + var retryBuf io.Reader + canRetry := true + if body != nil { + // If the authorization fails, we will need to restart reading + // from the passed body stream. + // When body is seekable, use seek to reset the streams + // cursor to the start. + // Otherwise, copy the stream into a buffer while uploading + // and use the buffers content on retry. + if sk, ok := body.(io.Seeker); ok { + if _, err = sk.Seek(0, io.SeekStart); err != nil { + return + } + retryBuf = body + } else if method == http.MethodPut { + canRetry = false + } else { + buff := &bytes.Buffer{} + retryBuf = buff + body = io.TeeReader(body, buff) + } + r, err = http.NewRequest(method, PathEscape(Join(c.root, path)), body) + } else { + r, err = http.NewRequest(method, PathEscape(Join(c.root, path)), nil) + } + + if err != nil { + return nil, err + } + + for k, vals := range c.headers { + for _, v := range vals { + r.Header.Add(k, v) + } + } + + // make sure we read 'c.auth' only once since it will be substituted below + // and that is unsafe to do when multiple goroutines are running at the same time. + c.authMutex.Lock() + auth := c.auth + c.authMutex.Unlock() + + auth.Authorize(r, method, path) + + if intercept != nil { + intercept(r) + } + + if c.interceptor != nil { + c.interceptor(method, r) + } + + rs, err := c.c.Do(r) + if err != nil { + return nil, err + } + + if rs.StatusCode == 401 && auth.Type() == "NoAuth" { + wwwAuthenticateHeader := strings.ToLower(rs.Header.Get("Www-Authenticate")) + + if strings.Index(wwwAuthenticateHeader, "digest") > -1 { + c.authMutex.Lock() + c.auth = &DigestAuth{auth.User(), auth.Pass(), digestParts(rs)} + c.authMutex.Unlock() + } else if strings.Index(wwwAuthenticateHeader, "basic") > -1 { + c.authMutex.Lock() + c.auth = &BasicAuth{auth.User(), auth.Pass()} + c.authMutex.Unlock() + } else { + return rs, newPathError("Authorize", c.root, rs.StatusCode) + } + + // retryBuf will be nil if body was nil initially so no check + // for body == nil is required here. 
+ if canRetry { + return c.req(method, path, retryBuf, intercept) + } + } else if rs.StatusCode == 401 { + return rs, newPathError("Authorize", c.root, rs.StatusCode) + } + + return rs, err +} + +func (c *Client) mkcol(path string) (status int, err error) { + rs, err := c.req("MKCOL", path, nil, nil) + if err != nil { + return + } + defer rs.Body.Close() + + status = rs.StatusCode + if status == 405 { + status = 201 + } + + return +} + +func (c *Client) options(path string) (*http.Response, error) { + return c.req("OPTIONS", path, nil, func(rq *http.Request) { + rq.Header.Add("Depth", "0") + }) +} + +func (c *Client) propfind(path string, self bool, body string, resp interface{}, parse func(resp interface{}) error) error { + rs, err := c.req("PROPFIND", path, strings.NewReader(body), func(rq *http.Request) { + if self { + rq.Header.Add("Depth", "0") + } else { + rq.Header.Add("Depth", "1") + } + rq.Header.Add("Content-Type", "application/xml;charset=UTF-8") + rq.Header.Add("Accept", "application/xml,text/xml") + rq.Header.Add("Accept-Charset", "utf-8") + // TODO add support for 'gzip,deflate;q=0.8,q=0.7' + rq.Header.Add("Accept-Encoding", "") + }) + if err != nil { + return err + } + defer rs.Body.Close() + + if rs.StatusCode != 207 { + return newPathError("PROPFIND", path, rs.StatusCode) + } + + return parseXML(rs.Body, resp, parse) +} + +func (c *Client) doCopyMove( + method string, + oldpath string, + newpath string, + overwrite bool, +) ( + status int, + r io.ReadCloser, + err error, +) { + rs, err := c.req(method, oldpath, nil, func(rq *http.Request) { + rq.Header.Add("Destination", PathEscape(Join(c.root, newpath))) + if overwrite { + rq.Header.Add("Overwrite", "T") + } else { + rq.Header.Add("Overwrite", "F") + } + }) + if err != nil { + return + } + status = rs.StatusCode + r = rs.Body + return +} + +func (c *Client) copymove(method string, oldpath string, newpath string, overwrite bool) (err error) { + s, data, err := c.doCopyMove(method, oldpath, newpath, overwrite) + if err != nil { + return + } + if data != nil { + defer data.Close() + } + + switch s { + case 201, 204: + return nil + + case 207: + // TODO handle multistat errors, worst case ... + log(fmt.Sprintf(" TODO handle %s - %s multistatus result %s", method, oldpath, String(data))) + + case 409: + err := c.createParentCollection(newpath) + if err != nil { + return err + } + + return c.copymove(method, oldpath, newpath, overwrite) + } + + return newPathError(method, oldpath, s) +} + +func (c *Client) put(path string, stream io.Reader, callback func(r *http.Request)) (status int, err error) { + rs, err := c.req(http.MethodPut, path, stream, callback) + if err != nil { + return + } + defer rs.Body.Close() + //all, _ := io.ReadAll(rs.Body) + //logrus.Debugln("put res: ", string(all)) + status = rs.StatusCode + return +} + +func (c *Client) createParentCollection(itemPath string) (err error) { + parentPath := path.Dir(itemPath) + if parentPath == "." 
|| parentPath == "/" { + return nil + } + + return c.MkdirAll(parentPath, 0755) +} diff --git a/pkg/gowebdav/utils.go b/pkg/gowebdav/utils.go new file mode 100644 index 0000000000000000000000000000000000000000..c7a65ad50ae4d91347b40d97fe69cbdb2813f8fb --- /dev/null +++ b/pkg/gowebdav/utils.go @@ -0,0 +1,118 @@ +package gowebdav + +import ( + "bytes" + "encoding/xml" + "fmt" + "io" + "net/url" + "strconv" + "strings" + "time" +) + +func log(msg interface{}) { + fmt.Println(msg) +} + +// PathEscape escapes all segments of a given path +func PathEscape(path string) string { + s := strings.Split(path, "/") + for i, e := range s { + s[i] = url.PathEscape(e) + } + return strings.Join(s, "/") +} + +// FixSlash appends a trailing / to our string +func FixSlash(s string) string { + if !strings.HasSuffix(s, "/") { + s += "/" + } + return s +} + +// FixSlashes appends and prepends a / if they are missing +func FixSlashes(s string) string { + if !strings.HasPrefix(s, "/") { + s = "/" + s + } + + return FixSlash(s) +} + +// Join joins two paths +func Join(path0 string, path1 string) string { + return strings.TrimSuffix(path0, "/") + "/" + strings.TrimPrefix(path1, "/") +} + +// String pulls a string out of our io.Reader +func String(r io.Reader) string { + buf := new(bytes.Buffer) + // TODO - make String return an error as well + _, _ = buf.ReadFrom(r) + return buf.String() +} + +func parseUint(s *string) uint { + if n, e := strconv.ParseUint(*s, 10, 32); e == nil { + return uint(n) + } + return 0 +} + +func parseInt64(s *string) int64 { + if n, e := strconv.ParseInt(*s, 10, 64); e == nil { + return n + } + return 0 +} + +func parseModified(s *string) time.Time { + if t, e := time.Parse(time.RFC1123, *s); e == nil { + return t + } + return time.Unix(0, 0) +} + +func parseXML(data io.Reader, resp interface{}, parse func(resp interface{}) error) error { + decoder := xml.NewDecoder(data) + for t, _ := decoder.Token(); t != nil; t, _ = decoder.Token() { + switch se := t.(type) { + case xml.StartElement: + if se.Name.Local == "response" { + if e := decoder.DecodeElement(resp, &se); e == nil { + if err := parse(resp); err != nil { + return err + } + } + } + } + } + return nil +} + +// limitedReadCloser wraps a io.ReadCloser and limits the number of bytes that can be read from it. 
+type limitedReadCloser struct { + rc io.ReadCloser + remaining int +} + +func (l *limitedReadCloser) Read(buf []byte) (int, error) { + if l.remaining <= 0 { + return 0, io.EOF + } + + if len(buf) > l.remaining { + buf = buf[0:l.remaining] + } + + n, err := l.rc.Read(buf) + l.remaining -= n + + return n, err +} + +func (l *limitedReadCloser) Close() error { + return l.rc.Close() +} diff --git a/pkg/gowebdav/utils_test.go b/pkg/gowebdav/utils_test.go new file mode 100644 index 0000000000000000000000000000000000000000..db7b0229ca17ec6e3bb69a8261c77ce86ede3704 --- /dev/null +++ b/pkg/gowebdav/utils_test.go @@ -0,0 +1,67 @@ +package gowebdav + +import ( + "fmt" + "net/url" + "testing" +) + +func TestJoin(t *testing.T) { + eq(t, "/", "", "") + eq(t, "/", "/", "/") + eq(t, "/foo", "", "/foo") + eq(t, "foo/foo", "foo/", "/foo") + eq(t, "foo/foo", "foo/", "foo") +} + +func eq(t *testing.T, expected string, s0 string, s1 string) { + s := Join(s0, s1) + if s != expected { + t.Error("For", "'"+s0+"','"+s1+"'", "expected", "'"+expected+"'", "got", "'"+s+"'") + } +} + +func ExamplePathEscape() { + fmt.Println(PathEscape("")) + fmt.Println(PathEscape("/")) + fmt.Println(PathEscape("/web")) + fmt.Println(PathEscape("/web/")) + fmt.Println(PathEscape("/w e b/d a v/s%u&c#k:s/")) + + // Output: + // + // / + // /web + // /web/ + // /w%20e%20b/d%20a%20v/s%25u&c%23k:s/ +} + +func TestEscapeURL(t *testing.T) { + ex := "https://foo.com/w%20e%20b/d%20a%20v/s%25u&c%23k:s/" + u, _ := url.Parse("https://foo.com" + PathEscape("/w e b/d a v/s%u&c#k:s/")) + if ex != u.String() { + t.Error("expected: " + ex + " got: " + u.String()) + } +} + +func TestFixSlashes(t *testing.T) { + expected := "/" + + if got := FixSlashes(""); got != expected { + t.Errorf("expected: %q, got: %q", expected, got) + } + + expected = "/path/" + + if got := FixSlashes("path"); got != expected { + t.Errorf("expected: %q, got: %q", expected, got) + } + + if got := FixSlashes("/path"); got != expected { + t.Errorf("expected: %q, got: %q", expected, got) + } + + if got := FixSlashes("path/"); got != expected { + t.Errorf("expected: %q, got: %q", expected, got) + } +} diff --git a/pkg/http_range/range.go b/pkg/http_range/range.go new file mode 100644 index 0000000000000000000000000000000000000000..5edd210d177d6e447b84bc3c5972ea04c743affd --- /dev/null +++ b/pkg/http_range/range.go @@ -0,0 +1,154 @@ +// Package http_range implements http range parsing. +package http_range + +import ( + "errors" + "fmt" + "net/http" + "net/textproto" + "strconv" + "strings" +) + +// Range specifies the byte range to be sent to the client. +type Range struct { + Start int64 + Length int64 // limit of bytes to read, -1 for unlimited +} + +// ContentRange returns Content-Range header value. +func (r Range) ContentRange(size int64) string { + return fmt.Sprintf("bytes %d-%d/%d", r.Start, r.Start+r.Length-1, size) +} + +var ( + // ErrNoOverlap is returned by ParseRange if first-byte-pos of + // all the byte-range-spec values is greater than the content size. + ErrNoOverlap = errors.New("invalid range: failed to overlap") + + // ErrInvalid is returned by ParseRange on invalid input. + ErrInvalid = errors.New("invalid range") +) + +// ParseRange parses a Range header string as per RFC 7233. +// ErrNoOverlap is returned if none of the ranges overlap. +// ErrInvalid is returned if s is invalid range. 
+func ParseRange(s string, size int64) ([]Range, error) { // nolint:gocognit
+	if s == "" {
+		return nil, nil // header not present
+	}
+	const b = "bytes="
+	if !strings.HasPrefix(s, b) {
+		return nil, ErrInvalid
+	}
+	var ranges []Range
+	noOverlap := false
+	for _, ra := range strings.Split(s[len(b):], ",") {
+		ra = textproto.TrimString(ra)
+		if ra == "" {
+			continue
+		}
+		i := strings.Index(ra, "-")
+		if i < 0 {
+			return nil, ErrInvalid
+		}
+		start, end := textproto.TrimString(ra[:i]), textproto.TrimString(ra[i+1:])
+		var r Range
+		if start == "" {
+			// If no start is specified, end specifies the
+			// range start relative to the end of the file,
+			// and we are dealing with <suffix-length>,
+			// which has to be a non-negative integer as per
+			// RFC 7233 Section 2.1 "Byte-Ranges".
+			if end == "" || end[0] == '-' {
+				return nil, ErrInvalid
+			}
+			i, err := strconv.ParseInt(end, 10, 64)
+			if i < 0 || err != nil {
+				return nil, ErrInvalid
+			}
+			if i > size {
+				i = size
+			}
+			r.Start = size - i
+			r.Length = size - r.Start
+		} else {
+			i, err := strconv.ParseInt(start, 10, 64)
+			if err != nil || i < 0 {
+				return nil, ErrInvalid
+			}
+			if i >= size {
+				// If the range begins after the size of the content,
+				// then it does not overlap.
+				noOverlap = true
+				continue
+			}
+			r.Start = i
+			if end == "" {
+				// If no end is specified, the range extends to the end of the file.
+				r.Length = size - r.Start
+			} else {
+				i, err := strconv.ParseInt(end, 10, 64)
+				if err != nil || r.Start > i {
+					return nil, ErrInvalid
+				}
+				if i >= size {
+					i = size - 1
+				}
+				r.Length = i - r.Start + 1
+			}
+		}
+		ranges = append(ranges, r)
+	}
+	if noOverlap && len(ranges) == 0 {
+		// The specified ranges did not overlap with the content.
+		return nil, ErrNoOverlap
+	}
+	return ranges, nil
+}
+
+// ParseContentRange parses a Content-Range header from an HTTP response.
+func ParseContentRange(s string) (start, end int64, err error) {
+	if s == "" {
+		return 0, 0, ErrInvalid
+	}
+	const b = "bytes "
+	if !strings.HasPrefix(s, b) {
+		return 0, 0, ErrInvalid
+	}
+	p1 := strings.Index(s, "-")
+	p2 := strings.Index(s, "/")
+	if p1 < 0 || p2 < 0 {
+		return 0, 0, ErrInvalid
+	}
+	startStr, endStr := textproto.TrimString(s[len(b):p1]), textproto.TrimString(s[p1+1:p2])
+	start, startErr := strconv.ParseInt(startStr, 10, 64)
+	end, endErr := strconv.ParseInt(endStr, 10, 64)
+
+	return start, end, errors.Join(startErr, endErr)
+}
+
+func (r Range) MimeHeader(contentType string, size int64) textproto.MIMEHeader {
+	return textproto.MIMEHeader{
+		"Content-Range": {r.ContentRange(size)},
+		"Content-Type":  {contentType},
+	}
+}
+
+// ApplyRangeToHttpHeader sets or clears the Range header on an HTTP request header
+func ApplyRangeToHttpHeader(p Range, headerRef http.Header) http.Header {
+	header := headerRef
+	if header == nil {
+		header = http.Header{}
+	}
+	if p.Start == 0 && p.Length < 0 {
+		header.Del("Range")
+	} else {
+		end := ""
+		if p.Length >= 0 {
+			end = strconv.FormatInt(p.Start+p.Length-1, 10)
+		}
+		header.Set("Range", fmt.Sprintf("bytes=%v-%v", p.Start, end))
+	}
+	return header
+}
diff --git a/pkg/mq/mq.go b/pkg/mq/mq.go
new file mode 100644
index 0000000000000000000000000000000000000000..35f5a159de42cd7693f1b17b96db5475e5251b9a
--- /dev/null
+++ b/pkg/mq/mq.go
@@ -0,0 +1,61 @@
+package mq
+
+import (
+	"sync"
+
+	"github.com/alist-org/alist/v3/pkg/generic"
+)
+
+type Message[T any] struct {
+	Content T
+}
+
+type BasicConsumer[T any] func(Message[T])
+type AllConsumer[T any] func([]Message[T])
+
+type MQ[T any] interface {
+	Publish(Message[T])
+	Consume(BasicConsumer[T])
ConsumeAll(AllConsumer[T]) + Clear() + Len() int +} + +type inMemoryMQ[T any] struct { + queue generic.Queue[Message[T]] + sync.Mutex +} + +func NewInMemoryMQ[T any]() MQ[T] { + return &inMemoryMQ[T]{queue: *generic.NewQueue[Message[T]]()} +} + +func (mq *inMemoryMQ[T]) Publish(msg Message[T]) { + mq.Lock() + defer mq.Unlock() + mq.queue.Push(msg) +} + +func (mq *inMemoryMQ[T]) Consume(consumer BasicConsumer[T]) { + mq.Lock() + defer mq.Unlock() + for !mq.queue.IsEmpty() { + consumer(mq.queue.Pop()) + } +} + +func (mq *inMemoryMQ[T]) ConsumeAll(consumer AllConsumer[T]) { + mq.Lock() + defer mq.Unlock() + consumer(mq.queue.PopAll()) +} + +func (mq *inMemoryMQ[T]) Clear() { + mq.Lock() + defer mq.Unlock() + mq.queue.Clear() +} + +func (mq *inMemoryMQ[T]) Len() int { + return mq.queue.Len() +} diff --git a/pkg/qbittorrent/client.go b/pkg/qbittorrent/client.go new file mode 100644 index 0000000000000000000000000000000000000000..ec3f7e7b00c5262aadaa1d828457dc3e16a79b88 --- /dev/null +++ b/pkg/qbittorrent/client.go @@ -0,0 +1,366 @@ +package qbittorrent + +import ( + "bytes" + "errors" + "io" + "mime/multipart" + "net/http" + "net/http/cookiejar" + "net/url" + + "github.com/alist-org/alist/v3/pkg/utils" +) + +type Client interface { + AddFromLink(link string, savePath string, id string) error + GetInfo(id string) (TorrentInfo, error) + GetFiles(id string) ([]FileInfo, error) + Delete(id string, deleteFiles bool) error +} + +type client struct { + url *url.URL + client http.Client + Client +} + +func New(webuiUrl string) (Client, error) { + u, err := url.Parse(webuiUrl) + if err != nil { + return nil, err + } + + jar, err := cookiejar.New(nil) + if err != nil { + return nil, err + } + var c = &client{ + url: u, + client: http.Client{Jar: jar}, + } + + err = c.checkAuthorization() + if err != nil { + return nil, err + } + return c, nil +} + +func (c *client) checkAuthorization() error { + // check authorization + if c.authorized() { + return nil + } + + // check authorization after logging in + err := c.login() + if err != nil { + return err + } + if c.authorized() { + return nil + } + return errors.New("unauthorized qbittorrent url") +} + +func (c *client) authorized() bool { + resp, err := c.post("/api/v2/app/version", nil) + if err != nil { + return false + } + return resp.StatusCode == 200 // the status code will be 403 if not authorized +} + +func (c *client) login() error { + // prepare HTTP request + v := url.Values{} + v.Set("username", c.url.User.Username()) + passwd, _ := c.url.User.Password() + v.Set("password", passwd) + resp, err := c.post("/api/v2/auth/login", v) + if err != nil { + return err + } + + // check result + body := make([]byte, 2) + _, err = resp.Body.Read(body) + if err != nil { + return err + } + if string(body) != "Ok" { + return errors.New("failed to login into qBittorrent webui with url: " + c.url.String()) + } + return nil +} + +func (c *client) post(path string, data url.Values) (*http.Response, error) { + u := c.url.JoinPath(path) + u.User = nil // remove userinfo for requests + + req, err := http.NewRequest("POST", u.String(), bytes.NewReader([]byte(data.Encode()))) + if err != nil { + return nil, err + } + if data != nil { + req.Header.Add("Content-Type", "application/x-www-form-urlencoded") + } + + resp, err := c.client.Do(req) + if err != nil { + return nil, err + } + if resp.Cookies() != nil { + c.client.Jar.SetCookies(u, resp.Cookies()) + } + return resp, nil +} + +func (c *client) AddFromLink(link string, savePath string, id string) error { + err := 
c.checkAuthorization()
+	if err != nil {
+		return err
+	}
+
+	buf := new(bytes.Buffer)
+	writer := multipart.NewWriter(buf)
+
+	addField := func(name string, value string) {
+		if err != nil {
+			return
+		}
+		err = writer.WriteField(name, value)
+	}
+	addField("urls", link)
+	addField("savepath", savePath)
+	addField("tags", "alist-"+id)
+	addField("autoTMM", "false")
+	if err != nil {
+		return err
+	}
+
+	err = writer.Close()
+	if err != nil {
+		return err
+	}
+
+	u := c.url.JoinPath("/api/v2/torrents/add")
+	u.User = nil // remove userinfo for requests
+	req, err := http.NewRequest("POST", u.String(), buf)
+	if err != nil {
+		return err
+	}
+	req.Header.Add("Content-Type", writer.FormDataContentType())
+
+	resp, err := c.client.Do(req)
+	if err != nil {
+		return err
+	}
+
+	// check result
+	body := make([]byte, 2)
+	_, err = resp.Body.Read(body)
+	if err != nil {
+		return err
+	}
+	if resp.StatusCode != 200 || string(body) != "Ok" {
+		return errors.New("failed to add qBittorrent task: " + link)
+	}
+	return nil
+}
+
+type TorrentStatus string
+
+const (
+	ERROR              TorrentStatus = "error"
+	MISSINGFILES       TorrentStatus = "missingFiles"
+	UPLOADING          TorrentStatus = "uploading"
+	PAUSEDUP           TorrentStatus = "pausedUP"
+	QUEUEDUP           TorrentStatus = "queuedUP"
+	STALLEDUP          TorrentStatus = "stalledUP"
+	CHECKINGUP         TorrentStatus = "checkingUP"
+	FORCEDUP           TorrentStatus = "forcedUP"
+	ALLOCATING         TorrentStatus = "allocating"
+	DOWNLOADING        TorrentStatus = "downloading"
+	METADL             TorrentStatus = "metaDL"
+	PAUSEDDL           TorrentStatus = "pausedDL"
+	QUEUEDDL           TorrentStatus = "queuedDL"
+	STALLEDDL          TorrentStatus = "stalledDL"
+	CHECKINGDL         TorrentStatus = "checkingDL"
+	FORCEDDL           TorrentStatus = "forcedDL"
+	CHECKINGRESUMEDATA TorrentStatus = "checkingResumeData"
+	MOVING             TorrentStatus = "moving"
+	UNKNOWN            TorrentStatus = "unknown"
+)
+
+// https://github.com/DGuang21/PTGo/blob/main/app/client/client_distributer.go
+type TorrentInfo struct {
+	AddedOn           int           `json:"added_on"`           // Time (Unix Epoch) when the torrent was added to the client
+	AmountLeft        int64         `json:"amount_left"`        // Amount of data left to download (bytes)
+	AutoTmm           bool          `json:"auto_tmm"`           // Whether this torrent is managed by Automatic Torrent Management
+	Availability      float64       `json:"availability"`       // Current availability percentage
+	Category          string        `json:"category"`           //
+	Completed         int64         `json:"completed"`          // Amount of transfer data completed (bytes)
+	CompletionOn      int           `json:"completion_on"`      // Time (Unix Epoch) when the torrent completed
+	ContentPath       string        `json:"content_path"`       // Absolute path of torrent content (root path for multi-file torrents, absolute file path for single-file torrents)
+	DlLimit           int           `json:"dl_limit"`           // Torrent download speed limit (bytes/s)
+	Dlspeed           int           `json:"dlspeed"`            // Torrent download speed (bytes/s)
+	Downloaded        int64         `json:"downloaded"`         // Amount of data downloaded
+	DownloadedSession int64         `json:"downloaded_session"` // Amount of data downloaded this session
+	Eta               int           `json:"eta"`                //
+	FLPiecePrio       bool          `json:"f_l_piece_prio"`     // True if the first and last pieces are prioritized
+	ForceStart        bool          `json:"force_start"`        // True if force start is enabled for this torrent
+	Hash              string        `json:"hash"`               //
+	LastActivity      int           `json:"last_activity"`      // Time (Unix Epoch) of last activity
+	MagnetURI         string        `json:"magnet_uri"`         // Magnet URI corresponding to this torrent
+	MaxRatio          float64       `json:"max_ratio"`          // Maximum share ratio before seeding/uploading is stopped
+	MaxSeedingTime    int           `json:"max_seeding_time"`   // Maximum seeding time (seconds) before the torrent is stopped from seeding
+	Name              string        `json:"name"`               //
+	NumComplete       int           `json:"num_complete"`       //
+	NumIncomplete     int           `json:"num_incomplete"`     //
+	NumLeechs         int           `json:"num_leechs"`         // Number of leechers connected to
+	NumSeeds          int           `json:"num_seeds"`          // Number of seeds connected to
+	Priority          int           `json:"priority"`           // Torrent priority; returns -1 if queuing is disabled or the torrent is in seed mode
+	Progress          float64       `json:"progress"`           // Torrent progress
+	Ratio             float64       `json:"ratio"`              // Torrent share ratio
+	RatioLimit        int           `json:"ratio_limit"`        //
+	SavePath          string        `json:"save_path"`
+	SeedingTime       int           `json:"seeding_time"`       // Torrent elapsed time while complete (seconds)
+	SeedingTimeLimit  int           `json:"seeding_time_limit"` // max_seeding_time
+	SeenComplete      int           `json:"seen_complete"`      // Time when the torrent was last seen complete
+	SeqDl             bool          `json:"seq_dl"`             // True if sequential download is enabled
+	Size              int64         `json:"size"`               //
+	State             TorrentStatus `json:"state"`              // See https://github.com/qbittorrent/qBittorrent/wiki/WebUI-API-(qBittorrent-4.1)#get-torrent-list
+	SuperSeeding      bool          `json:"super_seeding"`      // True if super seeding is enabled
+	Tags              string        `json:"tags"`               // Comma-separated tag list of the torrent
+	TimeActive        int           `json:"time_active"`        // Total active time (seconds)
+	TotalSize         int64         `json:"total_size"`         // Total size (bytes) of all files in this torrent (including unselected ones)
+	Tracker           string        `json:"tracker"`            // The first tracker with a working status; empty string if no tracker is working
+	TrackersCount     int           `json:"trackers_count"`     //
+	UpLimit           int           `json:"up_limit"`           // Upload speed limit
+	Uploaded          int64         `json:"uploaded"`           // Total amount uploaded
+	UploadedSession   int64         `json:"uploaded_session"`   // Amount uploaded this session
+	Upspeed           int           `json:"upspeed"`            // Upload speed (bytes/s)
+}
+
+type InfoNotFoundError struct {
+	Id  string
+	Err error
+}
+
+func (i InfoNotFoundError) Error() string {
+	return "there should be exactly one task with tag \"alist-" + i.Id + "\""
+}
+
+func NewInfoNotFoundError(id string) InfoNotFoundError {
+	return InfoNotFoundError{Id: id}
+}
+
+func (c *client) GetInfo(id string) (TorrentInfo, error) {
+	var infos []TorrentInfo
+
+	err := c.checkAuthorization()
+	if err != nil {
+		return TorrentInfo{}, err
+	}
+
+	v := url.Values{}
+	v.Set("tag", "alist-"+id)
+	response, err := c.post("/api/v2/torrents/info", v)
+	if err != nil {
+		return TorrentInfo{}, err
+	}
+
+	body, err := io.ReadAll(response.Body)
+	if err != nil {
+		return TorrentInfo{}, err
+	}
+	err = utils.Json.Unmarshal(body, &infos)
+	if err != nil {
+		return TorrentInfo{}, err
+	}
+	if len(infos) != 1 {
+		return TorrentInfo{}, NewInfoNotFoundError(id)
+	}
+	return infos[0], nil
+}
+
+type FileInfo struct {
+	Index        int     `json:"index"`
+	Name         string  `json:"name"`
+	Size         int64   `json:"size"`
+	Progress     float32 `json:"progress"`
+	Priority     int     `json:"priority"`
+	IsSeed       bool    `json:"is_seed"`
+	PieceRange   []int   `json:"piece_range"`
+	Availability float32 `json:"availability"`
+}
+
+func (c *client) GetFiles(id string) ([]FileInfo, error) {
+	var infos []FileInfo
+
+	err := c.checkAuthorization()
+	if err != nil {
+		return []FileInfo{}, err
+	}
+
+	tInfo, err := c.GetInfo(id)
+	if err != nil {
+		return []FileInfo{}, err
+	}
+
+	v := url.Values{}
+	v.Set("hash", tInfo.Hash)
+	response, err := c.post("/api/v2/torrents/files", v)
+	if err != nil {
+		return []FileInfo{}, err
+	}
+
+	body, err := io.ReadAll(response.Body)
+	if err != nil {
+		return []FileInfo{}, err
+	}
+	err = utils.Json.Unmarshal(body, &infos)
+	if err != nil {
+		return []FileInfo{}, err
+	}
+	return infos, nil
+}
+
+func (c *client) Delete(id string, deleteFiles bool) error {
+	err := c.checkAuthorization()
+	if err != nil {
+		return err
+	}
+
+	info, err := c.GetInfo(id)
+	if err != nil {
+		return err
+	}
+	v := url.Values{}
+	v.Set("hashes", info.Hash)
+	if deleteFiles {
+		v.Set("deleteFiles", "true")
+	} else {
+		v.Set("deleteFiles", "false")
+	}
+	response, err := c.post("/api/v2/torrents/delete", v)
+	if err != nil {
+		return err
+	}
+	if response.StatusCode != 200 {
+		return errors.New("failed to delete qbittorrent task")
+	}
+
+	v = url.Values{}
+	v.Set("tags", "alist-"+id)
+	response, err = c.post("/api/v2/torrents/deleteTags", v)
+	if err != nil {
+		return err
+	}
+	if response.StatusCode != 200 {
+		return
errors.New("failed to delete qbittorrent tag") + } + return nil +} diff --git a/pkg/sign/hmac.go b/pkg/sign/hmac.go new file mode 100644 index 0000000000000000000000000000000000000000..8d7f736b1f305c9c16a4674e47fae10116d5cf54 --- /dev/null +++ b/pkg/sign/hmac.go @@ -0,0 +1,52 @@ +package sign + +import ( + "crypto/hmac" + "crypto/sha256" + "encoding/base64" + "io" + "strconv" + "strings" + "time" +) + +type HMACSign struct { + SecretKey []byte +} + +func (s HMACSign) Sign(data string, expire int64) string { + h := hmac.New(sha256.New, s.SecretKey) + expireTimeStamp := strconv.FormatInt(expire, 10) + _, err := io.WriteString(h, data+":"+expireTimeStamp) + if err != nil { + return "" + } + + return base64.URLEncoding.EncodeToString(h.Sum(nil)) + ":" + expireTimeStamp +} + +func (s HMACSign) Verify(data, sign string) error { + signSlice := strings.Split(sign, ":") + // check whether contains expire time + if signSlice[len(signSlice)-1] == "" { + return ErrExpireMissing + } + // check whether expire time is expired + expires, err := strconv.ParseInt(signSlice[len(signSlice)-1], 10, 64) + if err != nil { + return ErrExpireInvalid + } + // if expire time is expired, return error + if expires < time.Now().Unix() && expires != 0 { + return ErrSignExpired + } + // verify sign + if s.Sign(data, expires) != sign { + return ErrSignInvalid + } + return nil +} + +func NewHMACSign(secret []byte) Sign { + return HMACSign{SecretKey: secret} +} diff --git a/pkg/sign/sign.go b/pkg/sign/sign.go new file mode 100644 index 0000000000000000000000000000000000000000..2a28667728012e536b1dc79a266ae16a4e1ad2a1 --- /dev/null +++ b/pkg/sign/sign.go @@ -0,0 +1,15 @@ +package sign + +import "errors" + +type Sign interface { + Sign(data string, expire int64) string + Verify(data, sign string) error +} + +var ( + ErrSignExpired = errors.New("sign expired") + ErrSignInvalid = errors.New("sign invalid") + ErrExpireInvalid = errors.New("expire invalid") + ErrExpireMissing = errors.New("expire missing") +) diff --git a/pkg/singleflight/signleflight_test.go b/pkg/singleflight/signleflight_test.go new file mode 100644 index 0000000000000000000000000000000000000000..34250299776f0f2293fb5ca798948f11cae6ab50 --- /dev/null +++ b/pkg/singleflight/signleflight_test.go @@ -0,0 +1,320 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package singleflight + +import ( + "bytes" + "errors" + "fmt" + "os" + "os/exec" + "runtime" + "runtime/debug" + "strings" + "sync" + "sync/atomic" + "testing" + "time" +) + +func TestDo(t *testing.T) { + var g Group[string] + v, err, _ := g.Do("key", func() (string, error) { + return "bar", nil + }) + if got, want := fmt.Sprintf("%v (%T)", v, v), "bar (string)"; got != want { + t.Errorf("Do = %v; want %v", got, want) + } + if err != nil { + t.Errorf("Do error = %v", err) + } +} + +func TestDoErr(t *testing.T) { + var g Group[any] + someErr := errors.New("Some error") + v, err, _ := g.Do("key", func() (any, error) { + return nil, someErr + }) + if err != someErr { + t.Errorf("Do error = %v; want someErr %v", err, someErr) + } + if v != nil { + t.Errorf("unexpected non-nil value %#v", v) + } +} + +func TestDoDupSuppress(t *testing.T) { + var g Group[string] + var wg1, wg2 sync.WaitGroup + c := make(chan string, 1) + var calls int32 + fn := func() (string, error) { + if atomic.AddInt32(&calls, 1) == 1 { + // First invocation. 
+ wg1.Done() + } + v := <-c + c <- v // pump; make available for any future calls + + time.Sleep(10 * time.Millisecond) // let more goroutines enter Do + + return v, nil + } + + const n = 10 + wg1.Add(1) + for i := 0; i < n; i++ { + wg1.Add(1) + wg2.Add(1) + go func() { + defer wg2.Done() + wg1.Done() + v, err, _ := g.Do("key", fn) + if err != nil { + t.Errorf("Do error: %v", err) + return + } + if v != "bar" { + t.Errorf("Do = %T %v; want %q", v, v, "bar") + } + }() + } + wg1.Wait() + // At least one goroutine is in fn now and all of them have at + // least reached the line before the Do. + c <- "bar" + wg2.Wait() + if got := atomic.LoadInt32(&calls); got <= 0 || got >= n { + t.Errorf("number of calls = %d; want over 0 and less than %d", got, n) + } +} + +// Test that singleflight behaves correctly after Forget called. +// See https://github.com/golang/go/issues/31420 +func TestForget(t *testing.T) { + var g Group[any] + + var ( + firstStarted = make(chan struct{}) + unblockFirst = make(chan struct{}) + firstFinished = make(chan struct{}) + ) + + go func() { + g.Do("key", func() (i any, e error) { + close(firstStarted) + <-unblockFirst + close(firstFinished) + return + }) + }() + <-firstStarted + g.Forget("key") + + unblockSecond := make(chan struct{}) + secondResult := g.DoChan("key", func() (i any, e error) { + <-unblockSecond + return 2, nil + }) + + close(unblockFirst) + <-firstFinished + + thirdResult := g.DoChan("key", func() (i any, e error) { + return 3, nil + }) + + close(unblockSecond) + <-secondResult + r := <-thirdResult + if r.Val != 2 { + t.Errorf("We should receive result produced by second call, expected: 2, got %d", r.Val) + } +} + +func TestDoChan(t *testing.T) { + var g Group[string] + ch := g.DoChan("key", func() (string, error) { + return "bar", nil + }) + + res := <-ch + v := res.Val + err := res.Err + if got, want := fmt.Sprintf("%v (%T)", v, v), "bar (string)"; got != want { + t.Errorf("Do = %v; want %v", got, want) + } + if err != nil { + t.Errorf("Do error = %v", err) + } +} + +// Test singleflight behaves correctly after Do panic. 
+// See https://github.com/golang/go/issues/41133 +func TestPanicDo(t *testing.T) { + var g Group[any] + fn := func() (any, error) { + panic("invalid memory address or nil pointer dereference") + } + + const n = 5 + waited := int32(n) + panicCount := int32(0) + done := make(chan struct{}) + for i := 0; i < n; i++ { + go func() { + defer func() { + if err := recover(); err != nil { + t.Logf("Got panic: %v\n%s", err, debug.Stack()) + atomic.AddInt32(&panicCount, 1) + } + + if atomic.AddInt32(&waited, -1) == 0 { + close(done) + } + }() + + g.Do("key", fn) + }() + } + + select { + case <-done: + if panicCount != n { + t.Errorf("Expect %d panic, but got %d", n, panicCount) + } + case <-time.After(time.Second): + t.Fatalf("Do hangs") + } +} + +func TestGoexitDo(t *testing.T) { + var g Group[any] + fn := func() (any, error) { + runtime.Goexit() + return nil, nil + } + + const n = 5 + waited := int32(n) + done := make(chan struct{}) + for i := 0; i < n; i++ { + go func() { + var err error + defer func() { + if err != nil { + t.Errorf("Error should be nil, but got: %v", err) + } + if atomic.AddInt32(&waited, -1) == 0 { + close(done) + } + }() + _, err, _ = g.Do("key", fn) + }() + } + + select { + case <-done: + case <-time.After(time.Second): + t.Fatalf("Do hangs") + } +} + +func TestPanicDoChan(t *testing.T) { + if runtime.GOOS == "js" { + t.Skipf("js does not support exec") + } + + if os.Getenv("TEST_PANIC_DOCHAN") != "" { + defer func() { + recover() + }() + + g := new(Group[any]) + ch := g.DoChan("", func() (any, error) { + panic("Panicking in DoChan") + }) + <-ch + t.Fatalf("DoChan unexpectedly returned") + } + + t.Parallel() + + cmd := exec.Command(os.Args[0], "-test.run="+t.Name(), "-test.v") + cmd.Env = append(os.Environ(), "TEST_PANIC_DOCHAN=1") + out := new(bytes.Buffer) + cmd.Stdout = out + cmd.Stderr = out + if err := cmd.Start(); err != nil { + t.Fatal(err) + } + + err := cmd.Wait() + t.Logf("%s:\n%s", strings.Join(cmd.Args, " "), out) + if err == nil { + t.Errorf("Test subprocess passed; want a crash due to panic in DoChan") + } + if bytes.Contains(out.Bytes(), []byte("DoChan unexpectedly")) { + t.Errorf("Test subprocess failed with an unexpected failure mode.") + } + if !bytes.Contains(out.Bytes(), []byte("Panicking in DoChan")) { + t.Errorf("Test subprocess failed, but the crash isn't caused by panicking in DoChan") + } +} + +func TestPanicDoSharedByDoChan(t *testing.T) { + if runtime.GOOS == "js" { + t.Skipf("js does not support exec") + } + + if os.Getenv("TEST_PANIC_DOCHAN") != "" { + blocked := make(chan struct{}) + unblock := make(chan struct{}) + + g := new(Group[any]) + go func() { + defer func() { + recover() + }() + g.Do("", func() (any, error) { + close(blocked) + <-unblock + panic("Panicking in Do") + }) + }() + + <-blocked + ch := g.DoChan("", func() (any, error) { + panic("DoChan unexpectedly executed callback") + }) + close(unblock) + <-ch + t.Fatalf("DoChan unexpectedly returned") + } + + t.Parallel() + + cmd := exec.Command(os.Args[0], "-test.run="+t.Name(), "-test.v") + cmd.Env = append(os.Environ(), "TEST_PANIC_DOCHAN=1") + out := new(bytes.Buffer) + cmd.Stdout = out + cmd.Stderr = out + if err := cmd.Start(); err != nil { + t.Fatal(err) + } + + err := cmd.Wait() + t.Logf("%s:\n%s", strings.Join(cmd.Args, " "), out) + if err == nil { + t.Errorf("Test subprocess passed; want a crash due to panic in Do shared by DoChan") + } + if bytes.Contains(out.Bytes(), []byte("DoChan unexpectedly")) { + t.Errorf("Test subprocess failed with an unexpected failure mode.") + } + if 
!bytes.Contains(out.Bytes(), []byte("Panicking in Do")) { + t.Errorf("Test subprocess failed, but the crash isn't caused by panicking in Do") + } +} diff --git a/pkg/singleflight/singleflight.go b/pkg/singleflight/singleflight.go new file mode 100644 index 0000000000000000000000000000000000000000..dcd84a3b9d4997067f01eeed10347d38a2a8c8f8 --- /dev/null +++ b/pkg/singleflight/singleflight.go @@ -0,0 +1,212 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package singleflight provides a duplicate function call suppression +// mechanism. +package singleflight + +import ( + "bytes" + "errors" + "fmt" + "runtime" + "runtime/debug" + "sync" +) + +// errGoexit indicates the runtime.Goexit was called in +// the user given function. +var errGoexit = errors.New("runtime.Goexit was called") + +// A panicError is an arbitrary value recovered from a panic +// with the stack trace during the execution of given function. +type panicError struct { + value any + stack []byte +} + +// Error implements error interface. +func (p *panicError) Error() string { + return fmt.Sprintf("%v\n\n%s", p.value, p.stack) +} + +func newPanicError(v any) error { + stack := debug.Stack() + + // The first line of the stack trace is of the form "goroutine N [status]:" + // but by the time the panic reaches Do the goroutine may no longer exist + // and its status will have changed. Trim out the misleading line. + if line := bytes.IndexByte(stack[:], '\n'); line >= 0 { + stack = stack[line+1:] + } + return &panicError{value: v, stack: stack} +} + +// call is an in-flight or completed singleflight.Do call +type call[T any] struct { + wg sync.WaitGroup + + // These fields are written once before the WaitGroup is done + // and are only read after the WaitGroup is done. + val T + err error + + // forgotten indicates whether Forget was called with this call's key + // while the call was still in flight. + forgotten bool + + // These fields are read and written with the singleflight + // mutex held before the WaitGroup is done, and are read but + // not written after the WaitGroup is done. + dups int + chans []chan<- Result[T] +} + +// Group represents a class of work and forms a namespace in +// which units of work can be executed with duplicate suppression. +type Group[T any] struct { + mu sync.Mutex // protects m + m map[string]*call[T] // lazily initialized +} + +// Result holds the results of Do, so they can be passed +// on a channel. +type Result[T any] struct { + Val T + Err error + Shared bool +} + +// Do executes and returns the results of the given function, making +// sure that only one execution is in-flight for a given key at a +// time. If a duplicate comes in, the duplicate caller waits for the +// original to complete and receives the same results. +// The return value shared indicates whether v was given to multiple callers. 
+func (g *Group[T]) Do(key string, fn func() (T, error)) (v T, err error, shared bool) { + g.mu.Lock() + if g.m == nil { + g.m = make(map[string]*call[T]) + } + if c, ok := g.m[key]; ok { + c.dups++ + g.mu.Unlock() + c.wg.Wait() + + if e, ok := c.err.(*panicError); ok { + panic(e) + } else if c.err == errGoexit { + runtime.Goexit() + } + return c.val, c.err, true + } + c := new(call[T]) + c.wg.Add(1) + g.m[key] = c + g.mu.Unlock() + + g.doCall(c, key, fn) + return c.val, c.err, c.dups > 0 +} + +// DoChan is like Do but returns a channel that will receive the +// results when they are ready. +// +// The returned channel will not be closed. +func (g *Group[T]) DoChan(key string, fn func() (T, error)) <-chan Result[T] { + ch := make(chan Result[T], 1) + g.mu.Lock() + if g.m == nil { + g.m = make(map[string]*call[T]) + } + if c, ok := g.m[key]; ok { + c.dups++ + c.chans = append(c.chans, ch) + g.mu.Unlock() + return ch + } + c := &call[T]{chans: []chan<- Result[T]{ch}} + c.wg.Add(1) + g.m[key] = c + g.mu.Unlock() + + go g.doCall(c, key, fn) + + return ch +} + +// doCall handles the single call for a key. +func (g *Group[T]) doCall(c *call[T], key string, fn func() (T, error)) { + normalReturn := false + recovered := false + + // use double-defer to distinguish panic from runtime.Goexit, + // more details see https://golang.org/cl/134395 + defer func() { + // the given function invoked runtime.Goexit + if !normalReturn && !recovered { + c.err = errGoexit + } + + c.wg.Done() + g.mu.Lock() + defer g.mu.Unlock() + if !c.forgotten { + delete(g.m, key) + } + + if e, ok := c.err.(*panicError); ok { + // In order to prevent the waiting channels from being blocked forever, + // needs to ensure that this panic cannot be recovered. + if len(c.chans) > 0 { + go panic(e) + select {} // Keep this goroutine around so that it will appear in the crash dump. + } else { + panic(e) + } + } else if c.err == errGoexit { + // Already in the process of goexit, no need to call again + } else { + // Normal return + for _, ch := range c.chans { + ch <- Result[T]{c.val, c.err, c.dups > 0} + } + } + }() + + func() { + defer func() { + if !normalReturn { + // Ideally, we would wait to take a stack trace until we've determined + // whether this is a panic or a runtime.Goexit. + // + // Unfortunately, the only way we can distinguish the two is to see + // whether the recover stopped the goroutine from terminating, and by + // the time we know that, the part of the stack trace relevant to the + // panic has been discarded. + if r := recover(); r != nil { + c.err = newPanicError(r) + } + } + }() + + c.val, c.err = fn() + normalReturn = true + }() + + if !normalReturn { + recovered = true + } +} + +// Forget tells the singleflight to forget about a key. Future calls +// to Do for this key will call the function rather than waiting for +// an earlier call to complete. 
+func (g *Group[T]) Forget(key string) { + g.mu.Lock() + if c, ok := g.m[key]; ok { + c.forgotten = true + } + delete(g.m, key) + g.mu.Unlock() +} diff --git a/pkg/tache/.gitattributes b/pkg/tache/.gitattributes new file mode 100644 index 0000000000000000000000000000000000000000..dfe0770424b2a19faf507a501ebfc23be8f54e7b --- /dev/null +++ b/pkg/tache/.gitattributes @@ -0,0 +1,2 @@ +# Auto detect text files and perform LF normalization +* text=auto diff --git a/pkg/tache/.gitignore b/pkg/tache/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..4bc60d76165b76fe383949ebd10e3fc106b77cfa --- /dev/null +++ b/pkg/tache/.gitignore @@ -0,0 +1,22 @@ +# If you prefer the allow list template instead of the deny list, see community template: +# https://github.com/github/gitignore/blob/main/community/Golang/Go.AllowList.gitignore +# +# Binaries for programs and plugins +*.exe +*.exe~ +*.dll +*.so +*.dylib + +# Test binary, built with `go test -c` +*.test + +# Output of the go coverage tool, specifically when used with LiteIDE +*.out + +# Dependency directories (remove the comment below to include it) +# vendor/ + +# Go workspace file +go.work +.idea \ No newline at end of file diff --git a/pkg/tache/LICENSE b/pkg/tache/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..4480662c09db621be5c96dd3ad4adac1c69d1ad9 --- /dev/null +++ b/pkg/tache/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023 Andy Hsu + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
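To illustrate the generic singleflight API added above, here is a minimal usage sketch (hypothetical caller code, not part of this diff; it only assumes the `pkg/singleflight` package shown):

```go
package main

import (
	"fmt"
	"time"

	"github.com/alist-org/alist/v3/pkg/singleflight"
)

func main() {
	var g singleflight.Group[string]
	done := make(chan struct{})
	for i := 0; i < 3; i++ {
		go func() {
			// The three goroutines share the key, so while the calls
			// overlap the function body typically runs once and the
			// other callers receive the same result with shared=true.
			v, err, shared := g.Do("config", func() (string, error) {
				time.Sleep(10 * time.Millisecond) // simulate a slow fetch
				return "bar", nil
			})
			fmt.Println(v, err, shared)
			done <- struct{}{}
		}()
	}
	for i := 0; i < 3; i++ {
		<-done
	}
}
```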
diff --git a/pkg/tache/README.md b/pkg/tache/README.md new file mode 100644 index 0000000000000000000000000000000000000000..41ac13bfd376e106f4336b148c9655a7360af042 --- /dev/null +++ b/pkg/tache/README.md @@ -0,0 +1,2 @@ +# tache + A task manager for golang diff --git a/pkg/tache/able.go b/pkg/tache/able.go new file mode 100644 index 0000000000000000000000000000000000000000..8ad92282b626121e28d1ccdc62294cf65994b54d --- /dev/null +++ b/pkg/tache/able.go @@ -0,0 +1,16 @@ +package tache + +// Persistable judge whether the task is persistable +type Persistable interface { + Persistable() bool +} + +// Recoverable judge whether the task is recoverable +type Recoverable interface { + Recoverable() bool +} + +// Retryable judge whether the task is retryable +type Retryable interface { + Retryable() bool +} diff --git a/pkg/tache/base.go b/pkg/tache/base.go new file mode 100644 index 0000000000000000000000000000000000000000..d42828924a1e8230a15e403ecaac6f9f960af991 --- /dev/null +++ b/pkg/tache/base.go @@ -0,0 +1,104 @@ +package tache + +import "context" + +// Base is the base struct for all tasks to implement TaskBase interface +type Base struct { + ID string `json:"id"` + State State `json:"state"` + Retry int `json:"retry"` + MaxRetry int `json:"max_retry"` + + progress float64 + size int64 + err error + ctx context.Context + cancel context.CancelFunc + persist func() +} + +func (b *Base) SetSize(size int64) { + b.size = size + b.Persist() +} + +func (b *Base) GetSize() int64 { + return b.size +} + +func (b *Base) SetProgress(progress float64) { + b.progress = progress + b.Persist() +} + +func (b *Base) GetProgress() float64 { + return b.progress +} + +func (b *Base) SetState(state State) { + b.State = state + b.Persist() +} + +func (b *Base) GetState() State { + return b.State +} + +func (b *Base) GetID() string { + return b.ID +} + +func (b *Base) SetID(id string) { + b.ID = id + b.Persist() +} + +func (b *Base) SetErr(err error) { + b.err = err + b.Persist() +} + +func (b *Base) GetErr() error { + return b.err +} + +func (b *Base) CtxDone() <-chan struct{} { + return b.Ctx().Done() +} + +func (b *Base) SetCtx(ctx context.Context) { + b.ctx = ctx +} + +func (b *Base) SetCancelFunc(cancelFunc context.CancelFunc) { + b.cancel = cancelFunc +} + +func (b *Base) GetRetry() (int, int) { + return b.Retry, b.MaxRetry +} + +func (b *Base) SetRetry(retry int, maxRetry int) { + b.Retry, b.MaxRetry = retry, maxRetry +} + +func (b *Base) Cancel() { + b.SetState(StateCanceling) + b.cancel() +} + +func (b *Base) Ctx() context.Context { + return b.ctx +} + +func (b *Base) Persist() { + if b.persist != nil { + b.persist() + } +} + +func (b *Base) SetPersist(persist func()) { + b.persist = persist +} + +var _ TaskBase = (*Base)(nil) diff --git a/pkg/tache/errors.go b/pkg/tache/errors.go new file mode 100644 index 0000000000000000000000000000000000000000..d14c068f738fc4221c97da8ce7d65f1c2db17ec2 --- /dev/null +++ b/pkg/tache/errors.go @@ -0,0 +1,40 @@ +package tache + +import "errors" + +// TacheError is a custom error type +type TacheError struct { + Msg string +} + +func (e *TacheError) Error() string { + return e.Msg +} + +// NewErr creates a new TacheError +func NewErr(msg string) error { + return &TacheError{Msg: msg} +} + +//var ( +// ErrTaskNotFound = NewErr("task not found") +// ErrTaskRunning = NewErr("task is running") +//) + +type unrecoverableError struct { + error +} + +func (e unrecoverableError) Unwrap() error { + return e.error +} + +// Unrecoverable wraps an error in `unrecoverableError` struct 
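Base plus the optional capability interfaces in able.go are all a concrete task needs. A hedged sketch of a minimal task type (UploadTask and its Path field are illustrative, not part of this diff): embed tache.Base, implement Run, and implement Retryable only to opt out of retries.

```go
package main

import (
	"fmt"

	"github.com/alist-org/alist/v3/pkg/tache"
)

type UploadTask struct {
	tache.Base
	Path string `json:"path"`
}

// Run does the actual work; here it just reports completion.
func (t *UploadTask) Run() error {
	t.SetProgress(100)
	return nil
}

// Retryable opts this task out of the manager's retry logic.
func (t *UploadTask) Retryable() bool { return false }

func main() {
	tm := tache.NewManager[*UploadTask]()
	tm.Add(&UploadTask{Path: "/tmp/a.txt"})
	tm.Wait()
	for _, t := range tm.GetAll() {
		fmt.Println(t.GetID(), t.GetState(), t.GetErr())
	}
}
```

Because Persistable, Recoverable, and Retryable are probed with type assertions, a task that implements none of them gets the permissive defaults: it is persisted, recovered, and retried.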
+func Unrecoverable(err error) error { + return unrecoverableError{err} +} + +// IsRecoverable checks if error is an instance of `unrecoverableError` +func IsRecoverable(err error) bool { + return !errors.Is(err, unrecoverableError{}) +} diff --git a/pkg/tache/examples/base/main.go b/pkg/tache/examples/base/main.go new file mode 100644 index 0000000000000000000000000000000000000000..22f11114481bfe885dffb5ae232a6cd771fc588b --- /dev/null +++ b/pkg/tache/examples/base/main.go @@ -0,0 +1 @@ +package base diff --git a/pkg/tache/hook.go b/pkg/tache/hook.go new file mode 100644 index 0000000000000000000000000000000000000000..c8a7427b42d747bbcf48079f55f96502f4b965f6 --- /dev/null +++ b/pkg/tache/hook.go @@ -0,0 +1,16 @@ +package tache + +// OnBeforeRetry is the interface for tasks that need to be executed before retrying +type OnBeforeRetry interface { + OnBeforeRetry() +} + +// OnSucceeded is the interface for tasks that need to be executed when they succeed +type OnSucceeded interface { + OnSucceeded() +} + +// OnFailed is the interface for tasks that need to be executed when they fail +type OnFailed interface { + OnFailed() +} diff --git a/pkg/tache/manager.go b/pkg/tache/manager.go new file mode 100644 index 0000000000000000000000000000000000000000..a8bbf939d033fd527ccfee387537f1a2cfc99d44 --- /dev/null +++ b/pkg/tache/manager.go @@ -0,0 +1,313 @@ +package tache + +import ( + "context" + "encoding/json" + "fmt" + "log/slog" + "os" + "runtime" + "sync/atomic" + + "github.com/jaevor/go-nanoid" + + "github.com/xhofe/gsync" +) + +// Manager is the manager of all tasks +type Manager[T Task] struct { + tasks gsync.MapOf[string, T] + queue gsync.QueueOf[T] + workers *WorkerPool[T] + opts *Options + debouncePersist func() + running atomic.Bool + + idGenerator func() string + logger *slog.Logger +} + +// NewManager create a new manager +func NewManager[T Task](opts ...Option) *Manager[T] { + options := DefaultOptions() + for _, opt := range opts { + opt(options) + } + nanoID, err := nanoid.Standard(21) + if err != nil { + panic(err) + } + m := &Manager[T]{ + workers: NewWorkerPool[T](options.Works), + opts: options, + idGenerator: nanoID, + logger: options.Logger, + } + m.running.Store(options.Running) + if m.opts.PersistPath != "" || (m.opts.PersistReadFunction != nil && m.opts.PersistWriteFunction != nil) { + m.debouncePersist = func() { + _ = m.persist() + } + if m.opts.PersistDebounce != nil { + m.debouncePersist = newDebounce(func() { + _ = m.persist() + }, *m.opts.PersistDebounce) + } + err := m.recover() + if err != nil { + m.logger.Error("recover error", "error", err) + } + } else { + m.debouncePersist = func() {} + } + return m +} + +// Add a task to manager +func (m *Manager[T]) Add(task T) { + ctx, cancel := context.WithCancel(context.Background()) + task.SetCtx(ctx) + task.SetCancelFunc(cancel) + task.SetPersist(m.debouncePersist) + if task.GetID() == "" { + task.SetID(m.idGenerator()) + } + if _, maxRetry := task.GetRetry(); maxRetry == 0 { + task.SetRetry(0, m.opts.MaxRetry) + } + if sliceContains([]State{StateRunning}, task.GetState()) { + task.SetState(StatePending) + } + if sliceContains([]State{StateCanceling}, task.GetState()) { + task.SetState(StateCanceled) + task.SetErr(context.Canceled) + } + if task.GetState() == StateFailing { + task.SetState(StateFailed) + } + m.tasks.Store(task.GetID(), task) + if !sliceContains([]State{StateSucceeded, StateCanceled, StateErrored, StateFailed}, task.GetState()) { + m.queue.Push(task) + } + m.debouncePersist() + m.next() +} + +// get next 
task from queue and execute it +func (m *Manager[T]) next() { + // if manager is not running, return + if !m.running.Load() { + return + } + // if workers is full, return + worker := m.workers.Get() + if worker == nil { + return + } + m.logger.Debug("got worker", "id", worker.ID) + task, err := m.queue.Pop() + // if cannot get task, return + if err != nil { + m.workers.Put(worker) + return + } + m.logger.Debug("got task", "id", task.GetID()) + go func() { + defer func() { + if task.GetState() == StateWaitingRetry { + m.queue.Push(task) + } + m.workers.Put(worker) + m.next() + }() + if task.GetState() == StateCanceling { + task.SetState(StateCanceled) + task.SetErr(context.Canceled) + return + } + if m.opts.Timeout != nil { + ctx, cancel := context.WithTimeout(task.Ctx(), *m.opts.Timeout) + defer cancel() + task.SetCtx(ctx) + } + m.logger.Info("worker execute task", "worker", worker.ID, "task", task.GetID()) + worker.Execute(task) + }() +} + +// Wait wait all tasks done, just for test +func (m *Manager[T]) Wait() { + for { + tasks, running := m.queue.Len(), m.workers.working.Load() + if tasks == 0 && running == 0 { + return + } + runtime.Gosched() + } +} + +// persist all tasks +func (m *Manager[T]) persist() error { + if m.opts.PersistPath == "" && m.opts.PersistReadFunction == nil && m.opts.PersistWriteFunction == nil { + return nil + } + // serialize tasks + tasks := m.GetAll() + var toPersist []T + for _, task := range tasks { + // only persist task which is not persistable or persistable and need persist + if p, ok := Task(task).(Persistable); !ok || p.Persistable() { + toPersist = append(toPersist, task) + } + } + marshal, err := json.Marshal(toPersist) + if err != nil { + return err + } + if m.opts.PersistReadFunction != nil && m.opts.PersistWriteFunction != nil { + err = m.opts.PersistWriteFunction(marshal) + if err != nil { + return err + } + } + if m.opts.PersistPath != "" { + // write to file + err = os.WriteFile(m.opts.PersistPath, marshal, 0644) + if err != nil { + return err + } + } + return nil +} + +// recover all tasks +func (m *Manager[T]) recover() error { + var data []byte + var err error + if m.opts.PersistPath != "" { + // read from file + data, err = os.ReadFile(m.opts.PersistPath) + } else if m.opts.PersistReadFunction != nil && m.opts.PersistWriteFunction != nil { + data, err = m.opts.PersistReadFunction() + } else { + return nil + } + if err != nil { + return err + } + // deserialize tasks + var tasks []T + err = json.Unmarshal(data, &tasks) + if err != nil { + return err + } + // add tasks + for _, task := range tasks { + // only recover task which is not recoverable or recoverable and need recover + if r, ok := Task(task).(Recoverable); !ok || r.Recoverable() { + m.Add(task) + } else { + task.SetState(StateFailed) + task.SetErr(fmt.Errorf("the task is interrupted and cannot be recovered")) + m.tasks.Store(task.GetID(), task) + } + } + return nil +} + +// Cancel a task by ID +func (m *Manager[T]) Cancel(id string) { + if task, ok := m.tasks.Load(id); ok { + task.Cancel() + m.debouncePersist() + } +} + +// CancelAll cancel all tasks +func (m *Manager[T]) CancelAll() { + m.tasks.Range(func(key string, value T) bool { + value.Cancel() + return true + }) + m.debouncePersist() +} + +// GetAll get all tasks +func (m *Manager[T]) GetAll() []T { + var tasks []T + m.tasks.Range(func(key string, value T) bool { + tasks = append(tasks, value) + return true + }) + return tasks +} + +// GetByID get task by ID +func (m *Manager[T]) GetByID(id string) (T, bool) { + return 
m.tasks.Load(id) +} + +// GetByState get tasks by state +func (m *Manager[T]) GetByState(state ...State) []T { + var tasks []T + m.tasks.Range(func(key string, value T) bool { + if sliceContains(state, value.GetState()) { + tasks = append(tasks, value) + } + return true + }) + return tasks +} + +// Remove a task by ID +func (m *Manager[T]) Remove(id string) { + m.tasks.Delete(id) + m.debouncePersist() +} + +// RemoveAll remove all tasks +func (m *Manager[T]) RemoveAll() { + tasks := m.GetAll() + for _, task := range tasks { + m.Remove(task.GetID()) + } +} + +// RemoveByState remove tasks by state +func (m *Manager[T]) RemoveByState(state ...State) { + tasks := m.GetByState(state...) + for _, task := range tasks { + m.Remove(task.GetID()) + } +} + +// Retry a task by ID +func (m *Manager[T]) Retry(id string) { + if task, ok := m.tasks.Load(id); ok { + task.SetState(StateWaitingRetry) + task.SetErr(nil) + task.SetRetry(0, m.opts.MaxRetry) + m.queue.Push(task) + m.next() + m.debouncePersist() + } +} + +// RetryAllFailed retry all failed tasks +func (m *Manager[T]) RetryAllFailed() { + tasks := m.GetByState(StateFailed) + for _, task := range tasks { + m.Retry(task.GetID()) + } +} + +// Start manager +func (m *Manager[T]) Start() { + m.running.Store(true) + m.next() +} + +// Pause manager +func (m *Manager[T]) Pause() { + m.running.Store(false) +} diff --git a/pkg/tache/manager_test.go b/pkg/tache/manager_test.go new file mode 100644 index 0000000000000000000000000000000000000000..7a91875b650ba4b334d7a417ed3528a866f10585 --- /dev/null +++ b/pkg/tache/manager_test.go @@ -0,0 +1,88 @@ +package tache_test + +import ( + "log/slog" + "os" + "sync/atomic" + "testing" + "time" + + "github.com/alist-org/alist/v3/pkg/tache" +) + +type TestTask struct { + tache.Base + Data string + do func(*TestTask) error +} + +func (t *TestTask) Run() error { + return t.do(t) +} + +func TestManager_Add(t *testing.T) { + tm := tache.NewManager[*TestTask]() + task := &TestTask{} + tm.Add(task) + t.Logf("%+v", task) +} + +func TestWithRetry(t *testing.T) { + tm := tache.NewManager[*TestTask](tache.WithMaxRetry(3), tache.WithWorks(1)) + var num atomic.Int64 + for i := int64(0); i < 10; i++ { + task := &TestTask{ + do: func(task *TestTask) error { + num.Add(1) + if num.Load() < i*3 { + return tache.NewErr("test") + } + return nil + }, + } + tm.Add(task) + } + tm.Wait() + tasks := tm.GetAll() + for _, task := range tasks { + t.Logf("%+v", task) + } +} + +func TestWithPersistPath(t *testing.T) { + tm := tache.NewManager[*TestTask](tache.WithPersistPath("./test.json")) + task := &TestTask{ + do: func(task *TestTask) error { + return nil + }, + Data: "haha", + } + tm.Add(task) + tm.Wait() + t.Logf("%+v", task) + time.Sleep(4 * time.Second) +} + +func TestMultiTasks(t *testing.T) { + tm := tache.NewManager[*TestTask](tache.WithWorks(3), tache.WithLogger(slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{ + AddSource: true, + Level: slog.LevelDebug, + ReplaceAttr: nil, + })))) + var num atomic.Int64 + for i := 0; i < 100; i++ { + tm.Add(&TestTask{ + do: func(task *TestTask) error { + num.Add(1) + return nil + }, + }) + } + tm.Wait() + //time.Sleep(3 * time.Second) + if num.Load() != 100 { + t.Errorf("num error, num: %d", num.Load()) + } else { + t.Logf("num success, num: %d", num.Load()) + } +} diff --git a/pkg/tache/option.go b/pkg/tache/option.go new file mode 100644 index 0000000000000000000000000000000000000000..c4544f95faf7867f5e0e9e0d44b2d1eee424ddeb --- /dev/null +++ b/pkg/tache/option.go @@ -0,0 +1,97 @@ 
+package tache + +import ( + "log/slog" + "time" +) + +// Options is the options for manager +type Options struct { + Works int + MaxRetry int + Timeout *time.Duration + PersistPath string + PersistDebounce *time.Duration + Running bool + Logger *slog.Logger + PersistReadFunction func() ([]byte, error) + PersistWriteFunction func([]byte) error +} + +// DefaultOptions returns default options +func DefaultOptions() *Options { + persistDebounce := 3 * time.Second + return &Options{ + Works: 5, + //MaxRetry: 1, + PersistDebounce: &persistDebounce, + Running: true, + Logger: slog.Default(), + } +} + +// Option is the option for manager +type Option func(*Options) + +// WithOptions set options +func WithOptions(opts Options) Option { + return func(o *Options) { + *o = opts + } +} + +// WithWorks set works +func WithWorks(works int) Option { + return func(o *Options) { + o.Works = works + } +} + +// WithMaxRetry set retry +func WithMaxRetry(maxRetry int) Option { + return func(o *Options) { + o.MaxRetry = maxRetry + } +} + +// WithTimeout set timeout +func WithTimeout(timeout time.Duration) Option { + return func(o *Options) { + o.Timeout = &timeout + } +} + +// WithPersistPath set persist path +func WithPersistPath(path string) Option { + return func(o *Options) { + o.PersistPath = path + } +} + +func WithPersistFunction(r func() ([]byte, error), w func([]byte) error) Option { + return func(o *Options) { + o.PersistReadFunction = r + o.PersistWriteFunction = w + } +} + +// WithPersistDebounce set persist debounce +func WithPersistDebounce(debounce time.Duration) Option { + return func(o *Options) { + o.PersistDebounce = &debounce + } +} + +// WithRunning set running +func WithRunning(running bool) Option { + return func(o *Options) { + o.Running = running + } +} + +// WithLogger set logger +func WithLogger(logger *slog.Logger) Option { + return func(o *Options) { + o.Logger = logger + } +} diff --git a/pkg/tache/state.go b/pkg/tache/state.go new file mode 100644 index 0000000000000000000000000000000000000000..3b0e058fa44165c8c9bdb60aee875391362eda6b --- /dev/null +++ b/pkg/tache/state.go @@ -0,0 +1,27 @@ +package tache + +// State is the state of a task +type State int + +const ( + // StatePending is the state of a task when it is pending + StatePending = iota + // StateRunning is the state of a task when it is running + StateRunning + // StateSucceeded is the state of a task when it succeeded + StateSucceeded + // StateCanceling is the state of a task when it is canceling + StateCanceling + // StateCanceled is the state of a task when it is canceled + StateCanceled + // StateErrored is the state of a task when it is errored (it will be retried) + StateErrored + // StateFailing is the state of a task when it is failing (executed OnFailed hook) + StateFailing + // StateFailed is the state of a task when it failed (no retry times left) + StateFailed + // StateWaitingRetry is the state of a task when it is waiting for retry + StateWaitingRetry + // StateBeforeRetry is the state of a task when it is executing OnBeforeRetry hook + StateBeforeRetry +) diff --git a/pkg/tache/task.go b/pkg/tache/task.go new file mode 100644 index 0000000000000000000000000000000000000000..7325521a96cb8ceba820d9de883d938d25bed4a7 --- /dev/null +++ b/pkg/tache/task.go @@ -0,0 +1,59 @@ +package tache + +import "context" + +// TaskBase is the base interface for all tasks +type TaskBase interface { + // SetProgress sets the progress of the task + SetProgress(progress float64) + // GetProgress gets the progress of the task + 
GetProgress() float64 + // SetState sets the state of the task + SetState(state State) + // GetState gets the state of the task + GetState() State + // GetID gets the ID of the task + GetID() string + // SetID sets the ID of the task + SetID(id string) + // SetErr sets the error of the task + SetErr(err error) + // GetErr gets the error of the task + GetErr() error + // SetCtx sets the context of the task + SetCtx(ctx context.Context) + // CtxDone gets the context done channel of the task + CtxDone() <-chan struct{} + // Cancel cancels the task + Cancel() + // Ctx gets the context of the task + Ctx() context.Context + // SetCancelFunc sets the cancel function of the task + SetCancelFunc(cancelFunc context.CancelFunc) + // GetRetry gets the retry of the task + GetRetry() (int, int) + // SetRetry sets the retry of the task + SetRetry(retry int, maxRetry int) + SetSize(size int64) + GetSize() int64 + // Persist persists the task + Persist() + // SetPersist sets the persist function of the task + SetPersist(persist func()) +} + +type Info interface { + GetName() string + GetStatus() string +} + +// Task is the interface for all tasks +type Task interface { + TaskBase + Run() error +} + +type TaskWithInfo interface { + Task + Info +} diff --git a/pkg/tache/utils.go b/pkg/tache/utils.go new file mode 100644 index 0000000000000000000000000000000000000000..09f439e3c7de684c606a14a20c29003f013d3dfe --- /dev/null +++ b/pkg/tache/utils.go @@ -0,0 +1,72 @@ +package tache + +import ( + "runtime" + "sync" + "time" +) + +// sliceContains checks if a slice contains a value +func sliceContains[T comparable](slice []T, v T) bool { + for _, vv := range slice { + if vv == v { + return true + } + } + return false +} + +// getCurrentGoroutineStack get current goroutine stack +func getCurrentGoroutineStack() string { + buf := make([]byte, 1<<16) + n := runtime.Stack(buf, false) + return string(buf[:n]) +} + +// newDebounce returns a debounced function +func newDebounce(f func(), interval time.Duration) func() { + var timer *time.Timer + var lock sync.Mutex + return func() { + lock.Lock() + defer lock.Unlock() + if timer == nil { + timer = time.AfterFunc(interval, f) + } else { + timer.Reset(interval) + } + } +} + +// isRetry checks if a task is retry executed +func isRetry[T Task](task T) bool { + return task.GetState() == StateWaitingRetry +} + +// isLastRetry checks if a task is last retry +func isLastRetry[T Task](task T) bool { + retry, maxRetry := task.GetRetry() + return retry >= maxRetry +} + +// needRetry judge whether the task need retry +func needRetry[T Task](task T) bool { + // if task is not recoverable, return false + if !IsRecoverable(task.GetErr()) { + return false + } + // if task is not retryable, return false + if r, ok := Task(task).(Retryable); ok && !r.Retryable() { + return false + } + // only retry when task is errored or failed + if sliceContains([]State{StateErrored, StateFailed}, task.GetState()) { + retry, maxRetry := task.GetRetry() + if retry < maxRetry { + task.SetRetry(retry+1, maxRetry) + task.SetState(StateWaitingRetry) + return true + } + } + return false +} diff --git a/pkg/tache/worker.go b/pkg/tache/worker.go new file mode 100644 index 0000000000000000000000000000000000000000..ecf9cfbc1c498542e37e9321e38c7ff594a4bfb5 --- /dev/null +++ b/pkg/tache/worker.go @@ -0,0 +1,92 @@ +package tache + +import ( + "context" + "errors" + "fmt" + "log" + "sync/atomic" +) + +// Worker is the worker to execute task +type Worker[T Task] struct { + ID int +} + +// Execute executes the task 
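needRetry above and Execute below cooperate: a failed Run moves the task to StateErrored, needRetry bumps the counter and parks it in StateWaitingRetry, the manager requeues it, and the worker fires OnBeforeRetry on the next pass; OnFailed fires only once retries are exhausted. A sketch of a task exercising that loop (FlakyTask is an illustrative name; it fails twice and succeeds on the third attempt):

```go
package main

import (
	"fmt"

	"github.com/alist-org/alist/v3/pkg/tache"
)

type FlakyTask struct {
	tache.Base
	attempts int
}

func (t *FlakyTask) Run() error {
	t.attempts++
	if t.attempts < 3 {
		return tache.NewErr("transient failure")
	}
	return nil
}

// Hooks from hook.go, discovered via type assertion in Execute.
func (t *FlakyTask) OnBeforeRetry() { fmt.Println("retrying, attempt", t.attempts) }
func (t *FlakyTask) OnFailed()      { fmt.Println("giving up:", t.GetErr()) }

func main() {
	tm := tache.NewManager[*FlakyTask](tache.WithMaxRetry(5))
	tm.Add(&FlakyTask{})
	tm.Wait() // returns once the third attempt succeeds
}
```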
+func (w Worker[T]) Execute(task T) { + if isRetry(task) { + task.SetState(StateBeforeRetry) + if hook, ok := Task(task).(OnBeforeRetry); ok { + hook.OnBeforeRetry() + } + } + onError := func(err error) { + task.SetErr(err) + if errors.Is(err, context.Canceled) { + task.SetState(StateCanceled) + } else { + task.SetState(StateErrored) + } + if !needRetry(task) { + if hook, ok := Task(task).(OnFailed); ok { + task.SetState(StateFailing) + hook.OnFailed() + } + task.SetState(StateFailed) + } + } + defer func() { + if err := recover(); err != nil { + log.Printf("error [%s] while run task [%s],stack trace:\n%s", err, task.GetID(), getCurrentGoroutineStack()) + onError(NewErr(fmt.Sprintf("panic: %v", err))) + } + }() + task.SetState(StateRunning) + err := task.Run() + if err != nil { + onError(err) + return + } + task.SetState(StateSucceeded) + if onSucceeded, ok := Task(task).(OnSucceeded); ok { + onSucceeded.OnSucceeded() + } + task.SetErr(nil) +} + +// WorkerPool is the pool of workers +type WorkerPool[T Task] struct { + working atomic.Int64 + workers chan *Worker[T] +} + +// NewWorkerPool creates a new worker pool +func NewWorkerPool[T Task](size int) *WorkerPool[T] { + workers := make(chan *Worker[T], size) + for i := 0; i < size; i++ { + workers <- &Worker[T]{ + ID: i, + } + } + return &WorkerPool[T]{ + workers: workers, + } +} + +// Get gets a worker from pool +func (wp *WorkerPool[T]) Get() *Worker[T] { + select { + case worker := <-wp.workers: + wp.working.Add(1) + return worker + default: + return nil + } +} + +// Put puts a worker back to pool +func (wp *WorkerPool[T]) Put(worker *Worker[T]) { + wp.workers <- worker + wp.working.Add(-1) +} diff --git a/pkg/task/errors.go b/pkg/task/errors.go new file mode 100644 index 0000000000000000000000000000000000000000..3f2c4302e154026762cf56345ff0e243fbf8a704 --- /dev/null +++ b/pkg/task/errors.go @@ -0,0 +1,8 @@ +package task + +import "errors" + +var ( + ErrTaskNotFound = errors.New("task not found") + ErrTaskRunning = errors.New("task is running") +) diff --git a/pkg/task/manager.go b/pkg/task/manager.go new file mode 100644 index 0000000000000000000000000000000000000000..cd7c69a225e292bc55d4c7a9421b04a25165eaac --- /dev/null +++ b/pkg/task/manager.go @@ -0,0 +1,145 @@ +package task + +import ( + "github.com/alist-org/alist/v3/pkg/generic_sync" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/pkg/errors" + log "github.com/sirupsen/logrus" +) + +type Manager[K comparable] struct { + curID K + workerC chan struct{} + updateID func(*K) + tasks generic_sync.MapOf[K, *Task[K]] +} + +func (tm *Manager[K]) Submit(task *Task[K]) K { + if tm.updateID != nil { + tm.updateID(&tm.curID) + task.ID = tm.curID + } + tm.tasks.Store(task.ID, task) + tm.do(task) + return task.ID +} + +func (tm *Manager[K]) do(task *Task[K]) { + go func() { + log.Debugf("task [%s] waiting for worker", task.Name) + select { + case <-tm.workerC: + log.Debugf("task [%s] starting", task.Name) + task.run() + log.Debugf("task [%s] ended", task.Name) + case <-task.Ctx.Done(): + log.Debugf("task [%s] canceled", task.Name) + return + } + // return worker + tm.workerC <- struct{}{} + }() +} + +func (tm *Manager[K]) GetAll() []*Task[K] { + return tm.tasks.Values() +} + +func (tm *Manager[K]) Get(tid K) (*Task[K], bool) { + return tm.tasks.Load(tid) +} + +func (tm *Manager[K]) MustGet(tid K) *Task[K] { + task, _ := tm.Get(tid) + return task +} + +func (tm *Manager[K]) Retry(tid K) error { + t, ok := tm.Get(tid) + if !ok { + return errors.WithStack(ErrTaskNotFound) + } + tm.do(t) 
+	return nil
+}
+
+func (tm *Manager[K]) Cancel(tid K) error {
+	t, ok := tm.Get(tid)
+	if !ok {
+		return errors.WithStack(ErrTaskNotFound)
+	}
+	t.Cancel()
+	return nil
+}
+
+func (tm *Manager[K]) Remove(tid K) error {
+	t, ok := tm.Get(tid)
+	if !ok {
+		return errors.WithStack(ErrTaskNotFound)
+	}
+	if !t.Done() {
+		return errors.WithStack(ErrTaskRunning)
+	}
+	tm.tasks.Delete(tid)
+	return nil
+}
+
+// RemoveAll removes all tasks from the manager. This probably shouldn't be
+// used, because some of the tasks may still be running.
+func (tm *Manager[K]) RemoveAll() {
+	tm.tasks.Clear()
+}
+
+func (tm *Manager[K]) RemoveByStates(states ...string) {
+	tasks := tm.GetAll()
+	for _, task := range tasks {
+		if utils.SliceContains(states, task.GetState()) {
+			_ = tm.Remove(task.ID)
+		}
+	}
+}
+
+func (tm *Manager[K]) GetByStates(states ...string) []*Task[K] {
+	var tasks []*Task[K]
+	tm.tasks.Range(func(key K, value *Task[K]) bool {
+		if utils.SliceContains(states, value.GetState()) {
+			tasks = append(tasks, value)
+		}
+		return true
+	})
+	return tasks
+}
+
+func (tm *Manager[K]) ListUndone() []*Task[K] {
+	return tm.GetByStates(PENDING, RUNNING, CANCELING)
+}
+
+func (tm *Manager[K]) ListDone() []*Task[K] {
+	return tm.GetByStates(SUCCEEDED, CANCELED, ERRORED)
+}
+
+func (tm *Manager[K]) ClearDone() {
+	tm.RemoveByStates(SUCCEEDED, CANCELED, ERRORED)
+}
+
+func (tm *Manager[K]) ClearSucceeded() {
+	tm.RemoveByStates(SUCCEEDED)
+}
+
+func (tm *Manager[K]) RawTasks() *generic_sync.MapOf[K, *Task[K]] {
+	return &tm.tasks
+}
+
+func NewTaskManager[K comparable](maxWorker int, updateID ...func(*K)) *Manager[K] {
+	tm := &Manager[K]{
+		tasks:   generic_sync.MapOf[K, *Task[K]]{},
+		workerC: make(chan struct{}, maxWorker),
+	}
+	for i := 0; i < maxWorker; i++ {
+		tm.workerC <- struct{}{}
+	}
+	if len(updateID) > 0 {
+		tm.updateID = updateID[0]
+	}
+	return tm
+}
diff --git a/pkg/task/task.go b/pkg/task/task.go
new file mode 100644
index 0000000000000000000000000000000000000000..5b634f10cdb4499d3274ed9be4414003f8e48117
--- /dev/null
+++ b/pkg/task/task.go
@@ -0,0 +1,124 @@
+// Package task manages tasks, such as file upload, file copy between storages, offline download, etc.
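For orientation before the file body below, a minimal usage sketch of this older pkg/task manager (it mirrors TestTask_Manager further down): IDs come from the updateID callback, and the buffered workerC channel caps concurrency at maxWorker.

```go
package main

import (
	"fmt"
	"sync/atomic"
	"time"

	"github.com/alist-org/alist/v3/pkg/task"
)

func main() {
	// Two workers; uint64 IDs allocated by the callback.
	tm := task.NewTaskManager(2, func(id *uint64) { atomic.AddUint64(id, 1) })
	id := tm.Submit(task.WithCancelCtx(&task.Task[uint64]{
		Name: "demo",
		Func: func(t *task.Task[uint64]) error {
			t.SetProgress(50)
			time.Sleep(10 * time.Millisecond)
			return nil
		},
	}))
	time.Sleep(100 * time.Millisecond) // crude wait; fine for a sketch
	t := tm.MustGet(id)
	fmt.Println(t.GetState(), t.GetProgress()) // succeeded 100
}
```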
+package task + +import ( + "context" + "runtime" + + "github.com/pkg/errors" + log "github.com/sirupsen/logrus" +) + +var ( + PENDING = "pending" + RUNNING = "running" + SUCCEEDED = "succeeded" + CANCELING = "canceling" + CANCELED = "canceled" + ERRORED = "errored" +) + +type Func[K comparable] func(task *Task[K]) error +type Callback[K comparable] func(task *Task[K]) + +type Task[K comparable] struct { + ID K + Name string + state string // pending, running, finished, canceling, canceled, errored + status string + progress float64 + + Error error + + Func Func[K] + callback Callback[K] + + Ctx context.Context + cancel context.CancelFunc +} + +func (t *Task[K]) SetStatus(status string) { + t.status = status +} + +func (t *Task[K]) SetProgress(percentage float64) { + t.progress = percentage +} + +func (t Task[K]) GetProgress() float64 { + return t.progress +} + +func (t Task[K]) GetState() string { + return t.state +} + +func (t Task[K]) GetStatus() string { + return t.status +} + +func (t Task[K]) GetErrMsg() string { + if t.Error == nil { + return "" + } + return t.Error.Error() +} + +func getCurrentGoroutineStack() string { + buf := make([]byte, 1<<16) + n := runtime.Stack(buf, false) + return string(buf[:n]) +} + +func (t *Task[K]) run() { + t.state = RUNNING + defer func() { + if err := recover(); err != nil { + log.Errorf("error [%s] while run task [%s],stack trace:\n%s", err, t.Name, getCurrentGoroutineStack()) + t.Error = errors.Errorf("panic: %+v", err) + t.state = ERRORED + } + }() + t.Error = t.Func(t) + if t.Error != nil { + log.Errorf("error [%+v] while run task [%s]", t.Error, t.Name) + } + if errors.Is(t.Ctx.Err(), context.Canceled) { + t.state = CANCELED + } else if t.Error != nil { + t.state = ERRORED + } else { + t.state = SUCCEEDED + t.SetProgress(100) + if t.callback != nil { + t.callback(t) + } + } +} + +func (t *Task[K]) retry() { + t.run() +} + +func (t *Task[K]) Done() bool { + return t.state == SUCCEEDED || t.state == CANCELED || t.state == ERRORED +} + +func (t *Task[K]) Cancel() { + if t.state == SUCCEEDED || t.state == CANCELED { + return + } + if t.cancel != nil { + t.cancel() + } + // maybe can't cancel + t.state = CANCELING +} + +func WithCancelCtx[K comparable](task *Task[K]) *Task[K] { + ctx, cancel := context.WithCancel(context.Background()) + task.Ctx = ctx + task.cancel = cancel + task.state = PENDING + return task +} diff --git a/pkg/task/task_test.go b/pkg/task/task_test.go new file mode 100644 index 0000000000000000000000000000000000000000..42236ca8a406635c44e711297ba31bb819d6b470 --- /dev/null +++ b/pkg/task/task_test.go @@ -0,0 +1,96 @@ +package task + +import ( + "sync/atomic" + "testing" + "time" + + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/pkg/errors" +) + +func TestTask_Manager(t *testing.T) { + tm := NewTaskManager(3, func(id *uint64) { + atomic.AddUint64(id, 1) + }) + id := tm.Submit(WithCancelCtx(&Task[uint64]{ + Name: "test", + Func: func(task *Task[uint64]) error { + time.Sleep(time.Millisecond * 500) + return nil + }, + })) + task, ok := tm.Get(id) + if !ok { + t.Fatal("task not found") + } + time.Sleep(time.Millisecond * 100) + if task.state != RUNNING { + t.Errorf("task status not running: %s", task.state) + } + time.Sleep(time.Second) + if task.state != SUCCEEDED { + t.Errorf("task status not finished: %s", task.state) + } +} + +func TestTask_Cancel(t *testing.T) { + tm := NewTaskManager(3, func(id *uint64) { + atomic.AddUint64(id, 1) + }) + id := tm.Submit(WithCancelCtx(&Task[uint64]{ + Name: "test", + Func: func(task 
*Task[uint64]) error { + for { + if utils.IsCanceled(task.Ctx) { + return nil + } else { + t.Logf("task is running") + } + } + }, + })) + task, ok := tm.Get(id) + if !ok { + t.Fatal("task not found") + } + time.Sleep(time.Microsecond * 50) + task.Cancel() + time.Sleep(time.Millisecond) + if task.state != CANCELED { + t.Errorf("task status not canceled: %s", task.state) + } +} + +func TestTask_Retry(t *testing.T) { + tm := NewTaskManager(3, func(id *uint64) { + atomic.AddUint64(id, 1) + }) + num := 0 + id := tm.Submit(WithCancelCtx(&Task[uint64]{ + Name: "test", + Func: func(task *Task[uint64]) error { + num++ + if num&1 == 1 { + return errors.New("test error") + } + return nil + }, + })) + task, ok := tm.Get(id) + if !ok { + t.Fatal("task not found") + } + time.Sleep(time.Millisecond) + if task.Error == nil { + t.Error(task.state) + t.Fatal("task error is nil, but expected error") + } else { + t.Logf("task error: %s", task.Error) + } + task.retry() + time.Sleep(time.Millisecond) + if task.Error != nil { + t.Errorf("task error: %+v, but expected nil", task.Error) + } +} diff --git a/pkg/utils/balance.go b/pkg/utils/balance.go new file mode 100644 index 0000000000000000000000000000000000000000..700d8c1cd1782f49b23d2c05cc5fd33c7211d524 --- /dev/null +++ b/pkg/utils/balance.go @@ -0,0 +1,18 @@ +package utils + +import "strings" + +var balance = ".balance" + +func IsBalance(str string) bool { + return strings.Contains(str, balance) +} + +// GetActualMountPath remove balance suffix +func GetActualMountPath(mountPath string) string { + bIndex := strings.LastIndex(mountPath, ".balance") + if bIndex != -1 { + mountPath = mountPath[:bIndex] + } + return mountPath +} diff --git a/pkg/utils/bool.go b/pkg/utils/bool.go new file mode 100644 index 0000000000000000000000000000000000000000..eecf550b0681aebc7f4739bde9a182272fee0b19 --- /dev/null +++ b/pkg/utils/bool.go @@ -0,0 +1,5 @@ +package utils + +func IsBool(bs ...bool) bool { + return len(bs) > 0 && bs[0] +} diff --git a/pkg/utils/ctx.go b/pkg/utils/ctx.go new file mode 100644 index 0000000000000000000000000000000000000000..89f27f8eb3e08638cd2a509d46ff17e8cbedf8e3 --- /dev/null +++ b/pkg/utils/ctx.go @@ -0,0 +1,14 @@ +package utils + +import ( + "context" +) + +func IsCanceled(ctx context.Context) bool { + select { + case <-ctx.Done(): + return true + default: + return false + } +} diff --git a/pkg/utils/email.go b/pkg/utils/email.go new file mode 100644 index 0000000000000000000000000000000000000000..833a70395185bc5e9c160d11c0a803a6ae9b23c4 --- /dev/null +++ b/pkg/utils/email.go @@ -0,0 +1,9 @@ +package utils + +import "regexp" + +func IsEmailFormat(email string) bool { + pattern := `^[0-9a-z][_.0-9a-z-]{0,31}@([0-9a-z][0-9a-z-]{0,30}[0-9a-z]\.){1,4}[a-z]{2,4}$` + reg := regexp.MustCompile(pattern) + return reg.MatchString(email) +} diff --git a/pkg/utils/file.go b/pkg/utils/file.go new file mode 100644 index 0000000000000000000000000000000000000000..54247636dcbd55bf382285589d49abc6e3395a2d --- /dev/null +++ b/pkg/utils/file.go @@ -0,0 +1,187 @@ +package utils + +import ( + "fmt" + "io" + "mime" + "os" + "path" + "path/filepath" + "strings" + + "github.com/alist-org/alist/v3/internal/errs" + + "github.com/alist-org/alist/v3/internal/conf" + log "github.com/sirupsen/logrus" +) + +// CopyFile File copies a single file from src to dst +func CopyFile(src, dst string) error { + var err error + var srcfd *os.File + var dstfd *os.File + var srcinfo os.FileInfo + + if srcfd, err = os.Open(src); err != nil { + return err + } + defer srcfd.Close() + + if 
dstfd, err = CreateNestedFile(dst); err != nil { + return err + } + defer dstfd.Close() + + if _, err = CopyWithBuffer(dstfd, srcfd); err != nil { + return err + } + if srcinfo, err = os.Stat(src); err != nil { + return err + } + return os.Chmod(dst, srcinfo.Mode()) +} + +// CopyDir Dir copies a whole directory recursively +func CopyDir(src, dst string) error { + var err error + var fds []os.DirEntry + var srcinfo os.FileInfo + + if srcinfo, err = os.Stat(src); err != nil { + return err + } + if err = os.MkdirAll(dst, srcinfo.Mode()); err != nil { + return err + } + if fds, err = os.ReadDir(src); err != nil { + return err + } + for _, fd := range fds { + srcfp := path.Join(src, fd.Name()) + dstfp := path.Join(dst, fd.Name()) + + if fd.IsDir() { + if err = CopyDir(srcfp, dstfp); err != nil { + fmt.Println(err) + } + } else { + if err = CopyFile(srcfp, dstfp); err != nil { + fmt.Println(err) + } + } + } + return nil +} + +// SymlinkOrCopyFile symlinks a file or copy if symlink failed +func SymlinkOrCopyFile(src, dst string) error { + if err := CreateNestedDirectory(filepath.Dir(dst)); err != nil { + return err + } + if err := os.Symlink(src, dst); err != nil { + return CopyFile(src, dst) + } + return nil +} + +// Exists determine whether the file exists +func Exists(name string) bool { + if _, err := os.Stat(name); err != nil { + if os.IsNotExist(err) { + return false + } + } + return true +} + +// CreateNestedDirectory create nested directory +func CreateNestedDirectory(path string) error { + err := os.MkdirAll(path, 0700) + if err != nil { + log.Errorf("can't create folder, %s", err) + } + return err +} + +// CreateNestedFile create nested file +func CreateNestedFile(path string) (*os.File, error) { + basePath := filepath.Dir(path) + if err := CreateNestedDirectory(basePath); err != nil { + return nil, err + } + return os.Create(path) +} + +// CreateTempFile create temp file from io.ReadCloser, and seek to 0 +func CreateTempFile(r io.Reader, size int64) (*os.File, error) { + if f, ok := r.(*os.File); ok { + return f, nil + } + f, err := os.CreateTemp(conf.Conf.TempDir, "file-*") + if err != nil { + return nil, err + } + readBytes, err := CopyWithBuffer(f, r) + if err != nil { + _ = os.Remove(f.Name()) + return nil, errs.NewErr(err, "CreateTempFile failed") + } + if size > 0 && readBytes != size { + _ = os.Remove(f.Name()) + return nil, errs.NewErr(err, "CreateTempFile failed, incoming stream actual size= %d, expect = %d ", readBytes, size) + } + _, err = f.Seek(0, io.SeekStart) + if err != nil { + _ = os.Remove(f.Name()) + return nil, errs.NewErr(err, "CreateTempFile failed, can't seek to 0 ") + } + return f, nil +} + +// GetFileType get file type +func GetFileType(filename string) int { + ext := strings.ToLower(Ext(filename)) + if SliceContains(conf.SlicesMap[conf.AudioTypes], ext) { + return conf.AUDIO + } + if SliceContains(conf.SlicesMap[conf.VideoTypes], ext) { + return conf.VIDEO + } + if SliceContains(conf.SlicesMap[conf.ImageTypes], ext) { + return conf.IMAGE + } + if SliceContains(conf.SlicesMap[conf.TextTypes], ext) { + return conf.TEXT + } + return conf.UNKNOWN +} + +func GetObjType(filename string, isDir bool) int { + if isDir { + return conf.FOLDER + } + return GetFileType(filename) +} + +var extraMimeTypes = map[string]string{ + ".apk": "application/vnd.android.package-archive", +} + +func GetMimeType(name string) string { + ext := path.Ext(name) + if m, ok := extraMimeTypes[ext]; ok { + return m + } + m := mime.TypeByExtension(ext) + if m != "" { + return m + } + return 
"application/octet-stream" +} + +const ( + KB = 1 << (10 * (iota + 1)) + MB + GB + TB +) diff --git a/pkg/utils/hash.go b/pkg/utils/hash.go new file mode 100644 index 0000000000000000000000000000000000000000..fa06bcc24c2c072ed836ccc76db7f9fbd7743da1 --- /dev/null +++ b/pkg/utils/hash.go @@ -0,0 +1,228 @@ +package utils + +import ( + "crypto/md5" + "crypto/sha1" + "crypto/sha256" + "encoding" + "encoding/hex" + "encoding/json" + "errors" + "hash" + "io" + + "github.com/alist-org/alist/v3/internal/errs" + log "github.com/sirupsen/logrus" +) + +func GetMD5EncodeStr(data string) string { + return HashData(MD5, []byte(data)) +} + +//inspired by "github.com/rclone/rclone/fs/hash" + +// ErrUnsupported should be returned by filesystem, +// if it is requested to deliver an unsupported hash type. +var ErrUnsupported = errors.New("hash type not supported") + +// HashType indicates a standard hashing algorithm +type HashType struct { + Width int + Name string + Alias string + NewFunc func(...any) hash.Hash +} + +func (ht *HashType) MarshalJSON() ([]byte, error) { + return []byte(`"` + ht.Name + `"`), nil +} + +func (ht *HashType) MarshalText() (text []byte, err error) { + return []byte(ht.Name), nil +} + +var ( + _ json.Marshaler = (*HashType)(nil) + //_ json.Unmarshaler = (*HashType)(nil) + + // read/write from/to json keys + _ encoding.TextMarshaler = (*HashType)(nil) + //_ encoding.TextUnmarshaler = (*HashType)(nil) +) + +var ( + name2hash = map[string]*HashType{} + alias2hash = map[string]*HashType{} + Supported []*HashType +) + +// RegisterHash adds a new Hash to the list and returns its Type +func RegisterHash(name, alias string, width int, newFunc func() hash.Hash) *HashType { + return RegisterHashWithParam(name, alias, width, func(a ...any) hash.Hash { return newFunc() }) +} + +func RegisterHashWithParam(name, alias string, width int, newFunc func(...any) hash.Hash) *HashType { + newType := &HashType{ + Name: name, + Alias: alias, + Width: width, + NewFunc: newFunc, + } + + name2hash[name] = newType + alias2hash[alias] = newType + Supported = append(Supported, newType) + return newType +} + +var ( + // MD5 indicates MD5 support + MD5 = RegisterHash("md5", "MD5", 32, md5.New) + + // SHA1 indicates SHA-1 support + SHA1 = RegisterHash("sha1", "SHA-1", 40, sha1.New) + + // SHA256 indicates SHA-256 support + SHA256 = RegisterHash("sha256", "SHA-256", 64, sha256.New) +) + +// HashData get hash of one hashType +func HashData(hashType *HashType, data []byte, params ...any) string { + h := hashType.NewFunc(params...) + h.Write(data) + return hex.EncodeToString(h.Sum(nil)) +} + +// HashReader get hash of one hashType from a reader +func HashReader(hashType *HashType, reader io.Reader, params ...any) (string, error) { + h := hashType.NewFunc(params...) + _, err := CopyWithBuffer(h, reader) + if err != nil { + return "", errs.NewErr(err, "HashReader error") + } + return hex.EncodeToString(h.Sum(nil)), nil +} + +// HashFile get hash of one hashType from a model.File +func HashFile(hashType *HashType, file io.ReadSeeker, params ...any) (string, error) { + str, err := HashReader(hashType, file, params...) + if err != nil { + return "", err + } + if _, err = file.Seek(0, io.SeekStart); err != nil { + return str, err + } + return str, nil +} + +// fromTypes will return hashers for all the requested types. 
+func fromTypes(types []*HashType) map[*HashType]hash.Hash { + hashers := map[*HashType]hash.Hash{} + for _, t := range types { + hashers[t] = t.NewFunc() + } + return hashers +} + +// toMultiWriter will return a set of hashers into a +// single multiwriter, where one write will update all +// the hashers. +func toMultiWriter(h map[*HashType]hash.Hash) io.Writer { + // Convert to to slice + var w = make([]io.Writer, 0, len(h)) + for _, v := range h { + w = append(w, v) + } + return io.MultiWriter(w...) +} + +// A MultiHasher will construct various hashes on all incoming writes. +type MultiHasher struct { + w io.Writer + size int64 + h map[*HashType]hash.Hash // Hashes +} + +// NewMultiHasher will return a hash writer that will write +// the requested hash types. +func NewMultiHasher(types []*HashType) *MultiHasher { + hashers := fromTypes(types) + m := MultiHasher{h: hashers, w: toMultiWriter(hashers)} + return &m +} + +func (m *MultiHasher) Write(p []byte) (n int, err error) { + n, err = m.w.Write(p) + m.size += int64(n) + return n, err +} + +func (m *MultiHasher) GetHashInfo() *HashInfo { + dst := make(map[*HashType]string) + for k, v := range m.h { + dst[k] = hex.EncodeToString(v.Sum(nil)) + } + return &HashInfo{h: dst} +} + +// Sum returns the specified hash from the multihasher +func (m *MultiHasher) Sum(hashType *HashType) ([]byte, error) { + h, ok := m.h[hashType] + if !ok { + return nil, ErrUnsupported + } + return h.Sum(nil), nil +} + +// Size returns the number of bytes written +func (m *MultiHasher) Size() int64 { + return m.size +} + +// A HashInfo contains hash string for one or more hashType +type HashInfo struct { + h map[*HashType]string `json:"hashInfo"` +} + +func NewHashInfoByMap(h map[*HashType]string) HashInfo { + return HashInfo{h} +} + +func NewHashInfo(ht *HashType, str string) HashInfo { + m := make(map[*HashType]string) + if ht != nil { + m[ht] = str + } + return HashInfo{h: m} +} + +func (hi HashInfo) String() string { + result, err := json.Marshal(hi.h) + if err != nil { + return "" + } + return string(result) +} +func FromString(str string) HashInfo { + hi := NewHashInfo(nil, "") + var tmp map[string]string + err := json.Unmarshal([]byte(str), &tmp) + if err != nil { + log.Warnf("failed to unmarsh HashInfo from string=%s", str) + } else { + for k, v := range tmp { + if name2hash[k] != nil && len(v) > 0 { + hi.h[name2hash[k]] = v + } + } + } + + return hi +} +func (hi HashInfo) GetHash(ht *HashType) string { + return hi.h[ht] +} + +func (hi HashInfo) Export() map[*HashType]string { + return hi.h +} diff --git a/pkg/utils/hash/gcid.go b/pkg/utils/hash/gcid.go new file mode 100644 index 0000000000000000000000000000000000000000..f6eccef7aeab4474777dbe65a8923ed25bc6684d --- /dev/null +++ b/pkg/utils/hash/gcid.go @@ -0,0 +1,98 @@ +package hash_extend + +import ( + "crypto/sha1" + "encoding" + "fmt" + "hash" + "strconv" + + "github.com/alist-org/alist/v3/pkg/utils" +) + +var GCID = utils.RegisterHashWithParam("gcid", "GCID", 40, func(a ...any) hash.Hash { + var ( + size int64 + err error + ) + if len(a) > 0 { + size, err = strconv.ParseInt(fmt.Sprint(a[0]), 10, 64) + if err != nil { + panic(err) + } + } + return NewGcid(size) +}) + +func NewGcid(size int64) hash.Hash { + calcBlockSize := func(j int64) int64 { + var psize int64 = 0x40000 + for float64(j)/float64(psize) > 0x200 && psize < 0x200000 { + psize = psize << 1 + } + return psize + } + + return &gcid{ + hash: sha1.New(), + hashState: sha1.New(), + blockSize: int(calcBlockSize(size)), + } +} + +type gcid struct 
{ + hash hash.Hash + hashState hash.Hash + blockSize int + + offset int +} + +func (h *gcid) Write(p []byte) (n int, err error) { + n = len(p) + for len(p) > 0 { + if h.offset < h.blockSize { + var lastSize = h.blockSize - h.offset + if lastSize > len(p) { + lastSize = len(p) + } + + h.hashState.Write(p[:lastSize]) + h.offset += lastSize + p = p[lastSize:] + } + + if h.offset >= h.blockSize { + h.hash.Write(h.hashState.Sum(nil)) + h.hashState.Reset() + h.offset = 0 + } + } + return +} + +func (h *gcid) Sum(b []byte) []byte { + if h.offset != 0 { + if hashm, ok := h.hash.(encoding.BinaryMarshaler); ok { + if hashum, ok := h.hash.(encoding.BinaryUnmarshaler); ok { + tempData, _ := hashm.MarshalBinary() + defer hashum.UnmarshalBinary(tempData) + h.hash.Write(h.hashState.Sum(nil)) + } + } + } + return h.hash.Sum(b) +} + +func (h *gcid) Reset() { + h.hash.Reset() + h.hashState.Reset() +} + +func (h *gcid) Size() int { + return h.hash.Size() +} + +func (h *gcid) BlockSize() int { + return h.blockSize +} diff --git a/pkg/utils/hash_test.go b/pkg/utils/hash_test.go new file mode 100644 index 0000000000000000000000000000000000000000..0f5a2a3b14e62210babff04a01263f524093f382 --- /dev/null +++ b/pkg/utils/hash_test.go @@ -0,0 +1,66 @@ +package utils + +import ( + "bytes" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "testing" +) + +type hashTest struct { + input []byte + output map[*HashType]string +} + +var hashTestSet = []hashTest{ + { + input: []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14}, + output: map[*HashType]string{ + MD5: "bf13fc19e5151ac57d4252e0e0f87abe", + SHA1: "3ab6543c08a75f292a5ecedac87ec41642d12166", + SHA256: "c839e57675862af5c21bd0a15413c3ec579e0d5522dab600bc6c3489b05b8f54", + }, + }, + // Empty data set + { + input: []byte{}, + output: map[*HashType]string{ + MD5: "d41d8cd98f00b204e9800998ecf8427e", + SHA1: "da39a3ee5e6b4b0d3255bfef95601890afd80709", + SHA256: "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", + }, + }, +} + +func TestMultiHasher(t *testing.T) { + for _, test := range hashTestSet { + mh := NewMultiHasher([]*HashType{MD5, SHA1, SHA256}) + n, err := CopyWithBuffer(mh, bytes.NewBuffer(test.input)) + require.NoError(t, err) + assert.Len(t, test.input, int(n)) + hashInfo := mh.GetHashInfo() + for k, v := range hashInfo.h { + expect, ok := test.output[k] + require.True(t, ok, "test output for hash not found") + assert.Equal(t, expect, v) + } + // Test that all are present + for k, v := range test.output { + expect, ok := hashInfo.h[k] + require.True(t, ok, "test output for hash not found") + assert.Equal(t, expect, v) + } + for k, v := range test.output { + expect := hashInfo.GetHash(k) + require.True(t, len(expect) > 0, "test output for hash not found") + assert.Equal(t, expect, v) + } + expect := hashInfo.GetHash(nil) + require.True(t, len(expect) == 0, "unknown type should return empty string") + str := hashInfo.String() + Log.Info("str=" + str) + newHi := FromString(str) + assert.Equal(t, newHi.h, hashInfo.h) + + } +} diff --git a/pkg/utils/io.go b/pkg/utils/io.go new file mode 100644 index 0000000000000000000000000000000000000000..7be989c3fd788355902222d98582400bd67866e5 --- /dev/null +++ b/pkg/utils/io.go @@ -0,0 +1,235 @@ +package utils + +import ( + "bytes" + "context" + "errors" + "fmt" + "io" + "sync" + "time" + + "golang.org/x/exp/constraints" + + log "github.com/sirupsen/logrus" +) + +// here is some syntaxic sugar inspired by the Tomas Senart's video, +// it allows me to inline the Reader interface 
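A usage sketch of CopyWithCtx below: the readerFunc shim checks the context before every read, so cancellation is observed between chunks rather than only at the end. All names here come from this file.

```go
package main

import (
	"bytes"
	"context"
	"fmt"
	"strings"

	"github.com/alist-org/alist/v3/pkg/utils"
)

func main() {
	const size = 1 << 20 // 1 MiB of input
	src := strings.NewReader(strings.Repeat("x", size))
	var dst bytes.Buffer

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel() // calling cancel() mid-copy aborts with context.Canceled

	var last float64
	err := utils.CopyWithCtx(ctx, &dst, src, size, func(p float64) { last = p })
	if err != nil {
		fmt.Println("copy failed:", err)
		return
	}
	fmt.Printf("copied %d bytes, final progress %.1f%%\n", dst.Len(), last)
}
```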
+type readerFunc func(p []byte) (n int, err error) + +func (rf readerFunc) Read(p []byte) (n int, err error) { return rf(p) } + +// CopyWithCtx slightly modified function signature: +// - context has been added in order to propagate cancellation +// - I do not return the number of bytes written, has it is not useful in my use case +func CopyWithCtx(ctx context.Context, out io.Writer, in io.Reader, size int64, progress func(percentage float64)) error { + // Copy will call the Reader and Writer interface multiple time, in order + // to copy by chunk (avoiding loading the whole file in memory). + // I insert the ability to cancel before read time as it is the earliest + // possible in the call process. + var finish int64 = 0 + s := size / 100 + _, err := CopyWithBuffer(out, readerFunc(func(p []byte) (int, error) { + // golang non-blocking channel: https://gobyexample.com/non-blocking-channel-operations + select { + // if context has been canceled + case <-ctx.Done(): + // stop process and propagate "context canceled" error + return 0, ctx.Err() + default: + // otherwise just run default io.Reader implementation + n, err := in.Read(p) + if s > 0 && (err == nil || err == io.EOF) { + finish += int64(n) + progress(float64(finish) / float64(s)) + } + return n, err + } + })) + return err +} + +type limitWriter struct { + w io.Writer + limit int64 +} + +func (l *limitWriter) Write(p []byte) (n int, err error) { + lp := len(p) + if l.limit > 0 { + if int64(lp) > l.limit { + p = p[:l.limit] + } + l.limit -= int64(len(p)) + _, err = l.w.Write(p) + } + return lp, err +} + +func LimitWriter(w io.Writer, limit int64) io.Writer { + return &limitWriter{w: w, limit: limit} +} + +type ReadCloser struct { + io.Reader + io.Closer +} + +type CloseFunc func() error + +func (c CloseFunc) Close() error { + return c() +} + +func NewReadCloser(reader io.Reader, close CloseFunc) io.ReadCloser { + return ReadCloser{ + Reader: reader, + Closer: close, + } +} + +func NewLimitReadCloser(reader io.Reader, close CloseFunc, limit int64) io.ReadCloser { + return NewReadCloser(io.LimitReader(reader, limit), close) +} + +type MultiReadable struct { + originReader io.Reader + reader io.Reader + cache *bytes.Buffer +} + +func NewMultiReadable(reader io.Reader) *MultiReadable { + return &MultiReadable{ + originReader: reader, + reader: reader, + } +} + +func (mr *MultiReadable) Read(p []byte) (int, error) { + n, err := mr.reader.Read(p) + if _, ok := mr.reader.(io.Seeker); !ok && n > 0 { + if mr.cache == nil { + mr.cache = &bytes.Buffer{} + } + mr.cache.Write(p[:n]) + } + return n, err +} + +func (mr *MultiReadable) Reset() error { + if seeker, ok := mr.reader.(io.Seeker); ok { + _, err := seeker.Seek(0, io.SeekStart) + return err + } + if mr.cache != nil && mr.cache.Len() > 0 { + mr.reader = io.MultiReader(mr.cache, mr.reader) + mr.cache = nil + } + return nil +} + +func (mr *MultiReadable) Close() error { + if closer, ok := mr.originReader.(io.Closer); ok { + return closer.Close() + } + return nil +} + +func Retry(attempts int, sleep time.Duration, f func() error) (err error) { + for i := 0; i < attempts; i++ { + fmt.Println("This is attempt number", i) + if i > 0 { + log.Println("retrying after error:", err) + time.Sleep(sleep) + sleep *= 2 + } + err = f() + if err == nil { + return nil + } + } + return fmt.Errorf("after %d attempts, last error: %s", attempts, err) +} + +type ClosersIF interface { + io.Closer + Add(closer io.Closer) + AddClosers(closers Closers) + GetClosers() Closers +} + +type Closers struct { + closers 
[]io.Closer +} + +func (c *Closers) GetClosers() Closers { + return *c +} + +var _ ClosersIF = (*Closers)(nil) + +func (c *Closers) Close() error { + var errs []error + for _, closer := range c.closers { + if closer != nil { + errs = append(errs, closer.Close()) + } + } + return errors.Join(errs...) +} +func (c *Closers) Add(closer io.Closer) { + c.closers = append(c.closers, closer) + +} +func (c *Closers) AddClosers(closers Closers) { + c.closers = append(c.closers, closers.closers...) +} + +func EmptyClosers() Closers { + return Closers{[]io.Closer{}} +} +func NewClosers(c ...io.Closer) Closers { + return Closers{c} +} + +func Min[T constraints.Ordered](a, b T) T { + if a < b { + return a + } + return b +} +func Max[T constraints.Ordered](a, b T) T { + if a < b { + return b + } + return a +} + +var IoBuffPool = &sync.Pool{ + New: func() interface{} { + return make([]byte, 32*1024*2) // Two times of size in io package + }, +} + +func CopyWithBuffer(dst io.Writer, src io.Reader) (written int64, err error) { + buff := IoBuffPool.Get().([]byte) + defer IoBuffPool.Put(buff) + written, err = io.CopyBuffer(dst, src, buff) + if err != nil { + return + } + return written, nil +} + +func CopyWithBufferN(dst io.Writer, src io.Reader, n int64) (written int64, err error) { + written, err = CopyWithBuffer(dst, io.LimitReader(src, n)) + if written == n { + return n, nil + } + if written < n && err == nil { + // src stopped early; must have been EOF. + err = io.EOF + } + return +} diff --git a/pkg/utils/ip.go b/pkg/utils/ip.go new file mode 100644 index 0000000000000000000000000000000000000000..5d108179de798002723f90de205c5fc7e85a3eeb --- /dev/null +++ b/pkg/utils/ip.go @@ -0,0 +1,49 @@ +package utils + +import ( + "net" + "net/http" + "strings" +) + +func ClientIP(r *http.Request) string { + xForwardedFor := r.Header.Get("X-Forwarded-For") + ip := strings.TrimSpace(strings.Split(xForwardedFor, ",")[0]) + if ip != "" { + return ip + } + + ip = strings.TrimSpace(r.Header.Get("X-Real-Ip")) + if ip != "" { + return ip + } + + if ip, _, err := net.SplitHostPort(strings.TrimSpace(r.RemoteAddr)); err == nil { + return ip + } + + return "" +} + +func IsLocalIPAddr(ip string) bool { + return IsLocalIP(net.ParseIP(ip)) +} + +func IsLocalIP(ip net.IP) bool { + if ip == nil { + return false + } + if ip.IsLoopback() { + return true + } + + ip4 := ip.To4() + if ip4 == nil { + return false + } + + return ip4[0] == 10 || // 10.0.0.0/8 + (ip4[0] == 172 && ip4[1] >= 16 && ip4[1] <= 31) || // 172.16.0.0/12 + (ip4[0] == 169 && ip4[1] == 254) || // 169.254.0.0/16 + (ip4[0] == 192 && ip4[1] == 168) // 192.168.0.0/16 +} diff --git a/pkg/utils/json.go b/pkg/utils/json.go new file mode 100644 index 0000000000000000000000000000000000000000..769d981817bcb58cbd9215039c854db169b050a2 --- /dev/null +++ b/pkg/utils/json.go @@ -0,0 +1,29 @@ +package utils + +import ( + stdjson "encoding/json" + "os" + + json "github.com/json-iterator/go" + log "github.com/sirupsen/logrus" +) + +var Json = json.ConfigCompatibleWithStandardLibrary + +// WriteJsonToFile write struct to json file +func WriteJsonToFile(dst string, data interface{}, std ...bool) bool { + str, err := json.MarshalIndent(data, "", " ") + if len(std) > 0 && std[0] { + str, err = stdjson.MarshalIndent(data, "", " ") + } + if err != nil { + log.Errorf("failed convert Conf to []byte:%s", err.Error()) + return false + } + err = os.WriteFile(dst, str, 0777) + if err != nil { + log.Errorf("failed to write json file:%s", err.Error()) + return false + } + return true +} diff --git 
a/pkg/utils/log.go b/pkg/utils/log.go
new file mode 100644
index 0000000000000000000000000000000000000000..79cc4f2b45bb4d0dc0475b3d076e7d2174efda49
--- /dev/null
+++ b/pkg/utils/log.go
@@ -0,0 +1,7 @@
+package utils
+
+import (
+	log "github.com/sirupsen/logrus"
+)
+
+var Log = log.New()
diff --git a/pkg/utils/map.go b/pkg/utils/map.go
new file mode 100644
index 0000000000000000000000000000000000000000..378ed1c152bb691d280830d89271c61da3d5fa0f
--- /dev/null
+++ b/pkg/utils/map.go
@@ -0,0 +1,11 @@
+package utils
+
+func MergeMap(mObj ...map[string]interface{}) map[string]interface{} {
+	newObj := map[string]interface{}{}
+	for _, m := range mObj {
+		for k, v := range m {
+			newObj[k] = v
+		}
+	}
+	return newObj
+}
diff --git a/pkg/utils/oauth2.go b/pkg/utils/oauth2.go
new file mode 100644
index 0000000000000000000000000000000000000000..c1ad161245fdccb1714e47427373161959631395
--- /dev/null
+++ b/pkg/utils/oauth2.go
@@ -0,0 +1,15 @@
+package utils
+
+import "golang.org/x/oauth2"
+
+type tokenSource struct {
+	fn func() (*oauth2.Token, error)
+}
+
+func (t *tokenSource) Token() (*oauth2.Token, error) {
+	return t.fn()
+}
+
+func TokenSource(fn func() (*oauth2.Token, error)) oauth2.TokenSource {
+	return &tokenSource{fn}
+}
diff --git a/pkg/utils/path.go b/pkg/utils/path.go
new file mode 100644
index 0000000000000000000000000000000000000000..c0793a3ec0f807c4706aa0490d82e20a8b788138
--- /dev/null
+++ b/pkg/utils/path.go
@@ -0,0 +1,96 @@
+package utils
+
+import (
+	"net/url"
+	stdpath "path"
+	"strings"
+
+	"github.com/alist-org/alist/v3/internal/errs"
+)
+
+// FixAndCleanPath
+// The upper layer of the root directory is still the root directory,
+// so ".." and "." will be cleared.
+// For example:
+// 1. ".." or "." => "/"
+// 2. "../..." or "./..." => "/..."
+// 3. "../.x." or "./.x." => "/.x."
+// 4. "x//\\y" => "/x/y"
+func FixAndCleanPath(path string) string {
+	path = strings.ReplaceAll(path, "\\", "/")
+	if !strings.HasPrefix(path, "/") {
+		path = "/" + path
+	}
+	return stdpath.Clean(path)
+}
+
+// PathAddSeparatorSuffix adds a '/' suffix to the path,
+// for example /root => /root/
+func PathAddSeparatorSuffix(path string) string {
+	if !strings.HasSuffix(path, "/") {
+		path = path + "/"
+	}
+	return path
+}
+
+// PathEqual reports whether two paths are equal after cleaning
+func PathEqual(path1, path2 string) bool {
+	return FixAndCleanPath(path1) == FixAndCleanPath(path2)
+}
+
+func IsSubPath(path string, subPath string) bool {
+	path, subPath = FixAndCleanPath(path), FixAndCleanPath(subPath)
+	return path == subPath || strings.HasPrefix(subPath, PathAddSeparatorSuffix(path))
+}
+
+func Ext(path string) string {
+	ext := stdpath.Ext(path)
+	if strings.HasPrefix(ext, ".") {
+		ext = ext[1:]
+	}
+	return strings.ToLower(ext)
+}
+
+func EncodePath(path string, all ...bool) string {
+	seg := strings.Split(path, "/")
+	toReplace := []struct {
+		Src string
+		Dst string
+	}{
+		{Src: "%", Dst: "%25"},
+		{Src: "?", Dst: "%3F"},
+		{Src: "#", Dst: "%23"},
+	}
+	for i := range seg {
+		if len(all) > 0 && all[0] {
+			seg[i] = url.PathEscape(seg[i])
+		} else {
+			for j := range toReplace {
+				seg[i] = strings.ReplaceAll(seg[i], toReplace[j].Src, toReplace[j].Dst)
+			}
+		}
+	}
+	return strings.Join(seg, "/")
+}
+
+func JoinBasePath(basePath, reqPath string) (string, error) {
+	/** relative path:
+	 * 1. ..
+	 * 2. ../
+	 * 3. /..
+	 * 4. /../
+	 * 5. /a/b/..
+	 */
+	if reqPath == ".."
+// JoinBasePath joins a user's base path with the requested path, rejecting
+// every form of relative-path traversal:
+//  1. ..
+//  2. ../
+//  3. /..
+//  4. /../
+//  5. /a/b/..
+func JoinBasePath(basePath, reqPath string) (string, error) {
+	if reqPath == ".." ||
+		strings.HasSuffix(reqPath, "/..") ||
+		strings.HasPrefix(reqPath, "../") ||
+		strings.Contains(reqPath, "/../") {
+		return "", errs.RelativePath
+	}
+	return stdpath.Join(FixAndCleanPath(basePath), FixAndCleanPath(reqPath)), nil
+}
+
+func GetFullPath(mountPath, path string) string {
+	return stdpath.Join(GetActualMountPath(mountPath), path)
+}
diff --git a/pkg/utils/path_test.go b/pkg/utils/path_test.go
new file mode 100644
index 0000000000000000000000000000000000000000..f42f2f8bb5d57f2b28b2dd6115f6be01e5e56366
--- /dev/null
+++ b/pkg/utils/path_test.go
@@ -0,0 +1,22 @@
+package utils
+
+import "testing"
+
+func TestEncodePath(t *testing.T) {
+	t.Log(EncodePath("http://localhost:5244/d/123#.png"))
+}
+
+func TestFixAndCleanPath(t *testing.T) {
+	datas := map[string]string{
+		"":                          "/",
+		".././":                     "/",
+		"../../.../":                "/...",
+		"x//\\y/":                   "/x/y",
+		".././.x/.y/.//..x../..y..": "/.x/.y/..x../..y..",
+	}
+	for key, value := range datas {
+		// Use t.Errorf so a mismatch actually fails the test; the original
+		// t.Logf only logged and could never fail.
+		if got := FixAndCleanPath(key); got != value {
+			t.Errorf("FixAndCleanPath(%q) = %q, want %q", key, got, value)
+		}
+	}
+}
diff --git a/pkg/utils/random/random.go b/pkg/utils/random/random.go
new file mode 100644
index 0000000000000000000000000000000000000000..65fbf14a0d3feaf611a47e782b1a5d37e0e904b1
--- /dev/null
+++ b/pkg/utils/random/random.go
@@ -0,0 +1,33 @@
+package random
+
+import (
+	"math/rand"
+	"time"
+
+	"github.com/google/uuid"
+)
+
+var Rand *rand.Rand
+
+const letterBytes = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"
+
+// String returns a random alphanumeric string of length n.
+func String(n int) string {
+	b := make([]byte, n)
+	for i := range b {
+		b[i] = letterBytes[Rand.Intn(len(letterBytes))]
+	}
+	return string(b)
+}
+
+func Token() string {
+	return "alist-" + uuid.NewString() + String(64)
+}
+
+// RangeInt64 returns a random int64 in [left, right). The original
+// rand.Int63n(left+right)-left produced values in [-left, right) and panicked
+// whenever left+right <= 0; it also bypassed the seeded Rand.
+func RangeInt64(left, right int64) int64 {
+	return Rand.Int63n(right-left) + left
+}
+
+func init() {
+	s := rand.NewSource(time.Now().UnixNano())
+	Rand = rand.New(s)
+}
diff --git a/pkg/utils/slice.go b/pkg/utils/slice.go
new file mode 100644
index 0000000000000000000000000000000000000000..842995daaf1e9116fe48278de7d53ee22ae108d2
--- /dev/null
+++ b/pkg/utils/slice.go
@@ -0,0 +1,101 @@
+package utils
+
+import (
+	"strings"
+
+	"github.com/pkg/errors"
+)
+
+// SliceEqual checks if two slices are equal.
+func SliceEqual[T comparable](a, b []T) bool {
+	if len(a) != len(b) {
+		return false
+	}
+	for i, v := range a {
+		if v != b[i] {
+			return false
+		}
+	}
+	return true
+}
+
+// SliceContains checks if the slice contains the element.
+func SliceContains[T comparable](arr []T, v T) bool {
+	for _, vv := range arr {
+		if vv == v {
+			return true
+		}
+	}
+	return false
+}
+
+// SliceAllContains checks if the slice contains all of the given elements.
+func SliceAllContains[T comparable](arr []T, vs ...T) bool {
+	vsMap := make(map[T]struct{})
+	for _, v := range arr {
+		vsMap[v] = struct{}{}
+	}
+	for _, v := range vs {
+		if _, ok := vsMap[v]; !ok {
+			return false
+		}
+	}
+	return true
+}
+
+// SliceConvert converts a slice to a slice of another type, stopping at the
+// first conversion error.
+func SliceConvert[S any, D any](srcS []S, convert func(src S) (D, error)) ([]D, error) {
+	res := make([]D, 0, len(srcS))
+	for i := range srcS {
+		dst, err := convert(srcS[i])
+		if err != nil {
+			return nil, err
+		}
+		res = append(res, dst)
+	}
+	return res, nil
+}
+
+func MustSliceConvert[S any, D any](srcS []S, convert func(src S) D) []D {
+	res := make([]D, 0, len(srcS))
+	for i := range srcS {
+		res = append(res, convert(srcS[i]))
+	}
+	return res
+}
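SliceConvert and MustSliceConvert above are the map-over-slice helpers the rest of the codebase leans on; the former short-circuits on the first conversion error, the latter assumes conversion cannot fail. A small usage sketch, not part of the patch, with expected output in comments:

package main

import (
	"fmt"
	"strconv"

	"github.com/alist-org/alist/v3/pkg/utils"
)

func main() {
	// Fallible conversion: the first failing element aborts the whole batch.
	nums, err := utils.SliceConvert([]string{"1", "2", "x"}, strconv.Atoi)
	fmt.Println(nums, err) // [] strconv.Atoi: parsing "x": invalid syntax

	// Infallible mapping.
	squares := utils.MustSliceConvert([]int{1, 2, 3}, func(n int) string {
		return strconv.Itoa(n * n)
	})
	fmt.Println(squares) // [1 4 9]
}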
+// MergeErrors joins the messages of the given errors with newlines; it
+// returns nil when there are no errors or every message is empty.
+func MergeErrors(errs ...error) error {
+	errStr := strings.Join(MustSliceConvert(errs, func(err error) string {
+		return err.Error()
+	}), "\n")
+	if errStr != "" {
+		return errors.New(errStr)
+	}
+	return nil
+}
+
+func SliceMeet[T1, T2 any](arr []T1, v T2, meet func(item T1, v T2) bool) bool {
+	for _, item := range arr {
+		if meet(item, v) {
+			return true
+		}
+	}
+	return false
+}
+
+func SliceFilter[T any](arr []T, filter func(src T) bool) []T {
+	res := make([]T, 0, len(arr))
+	for _, src := range arr {
+		if filter(src) {
+			res = append(res, src)
+		}
+	}
+	return res
+}
+
+func SliceReplace[T any](arr []T, replace func(src T) T) {
+	for i, src := range arr {
+		arr[i] = replace(src)
+	}
+}
diff --git a/pkg/utils/str.go b/pkg/utils/str.go
new file mode 100644
index 0000000000000000000000000000000000000000..e42484dc78bfd83f07a94b6523380dd0b8baa522
--- /dev/null
+++ b/pkg/utils/str.go
@@ -0,0 +1,42 @@
+package utils
+
+import (
+	"encoding/base64"
+	"strings"
+
+	"github.com/alist-org/alist/v3/internal/conf"
+)
+
+// MappingName rewrites characters of name according to conf.FilenameCharMap.
+func MappingName(name string) string {
+	for k, v := range conf.FilenameCharMap {
+		name = strings.ReplaceAll(name, k, v)
+	}
+	return name
+}
+
+// DEC maps URL-safe base64 characters back to the standard alphabet.
+var DEC = map[string]string{
+	"-": "+",
+	"_": "/",
+	".": "=",
+}
+
+// SafeAtob decodes base64 input, tolerating the URL-safe alphabet.
+func SafeAtob(data string) (string, error) {
+	for k, v := range DEC {
+		data = strings.ReplaceAll(data, k, v)
+	}
+	bytes, err := base64.StdEncoding.DecodeString(data)
+	if err != nil {
+		return "", err
+	}
+	return string(bytes), nil
+}
+
+// GetNoneEmpty returns the first non-empty string; it returns "" if all are empty.
+func GetNoneEmpty(strArr ...string) string {
+	for _, s := range strArr {
+		if len(s) > 0 {
+			return s
+		}
+	}
+	return ""
+}
diff --git a/pkg/utils/time.go b/pkg/utils/time.go
new file mode 100644
index 0000000000000000000000000000000000000000..aa7069282fbf7aa4aabcefd7a3c537885a6124de
--- /dev/null
+++ b/pkg/utils/time.go
@@ -0,0 +1,64 @@
+package utils
+
+import (
+	"sync"
+	"time"
+)
+
+// CNLoc is the UTC+8 zone; the label was previously the misleading "UTC".
+var CNLoc = time.FixedZone("UTC+8", 8*60*60)
+
+// MustParseCNTime parses "2006-01-02 15:04:05" as UTC+8; on failure it
+// returns the zero time instead of panicking.
+func MustParseCNTime(str string) time.Time {
+	lastOpTime, _ := time.ParseInLocation("2006-01-02 15:04:05 -07", str+" +08", CNLoc)
+	return lastOpTime
+}
+
+// NewDebounce returns a debouncer: f runs only after interval has elapsed
+// without another call.
+func NewDebounce(interval time.Duration) func(f func()) {
+	var timer *time.Timer
+	var lock sync.Mutex
+	return func(f func()) {
+		lock.Lock()
+		defer lock.Unlock()
+		if timer != nil {
+			timer.Stop()
+		}
+		timer = time.AfterFunc(interval, f)
+	}
+}
+
+func NewDebounce2(interval time.Duration, f func()) func() {
+	var timer *time.Timer
+	var lock sync.Mutex
+	return func() {
+		lock.Lock()
+		defer lock.Unlock()
+		if timer == nil {
+			timer = time.AfterFunc(interval, f)
+		}
+		// timer is already a *time.Timer; the original (*time.Timer)(timer)
+		// conversion was a no-op.
+		timer.Reset(interval)
+	}
+}
+
+// NewThrottle returns a throttler: at most one call per interval is accepted,
+// and the accepted fn runs on the trailing edge of the window.
+func NewThrottle(interval time.Duration) func(func()) {
+	var lastCall time.Time
+
+	return func(fn func()) {
+		now := time.Now()
+		if now.Sub(lastCall) < interval {
+			return
+		}
+		time.AfterFunc(interval, fn)
+		lastCall = now
+	}
+}
+
+func NewThrottle2(interval time.Duration, fn func()) func() {
+	var lastCall time.Time
+	return func() {
+		now := time.Now()
+		if now.Sub(lastCall) < interval {
+			return
+		}
+		time.AfterFunc(interval, fn)
+		lastCall = now
+	}
+}
diff --git a/pkg/utils/url.go b/pkg/utils/url.go
new file mode 100644
index 0000000000000000000000000000000000000000..16da1519c97e24031aee599e543d801769e43680
--- /dev/null
+++ b/pkg/utils/url.go
@@ -0,0 +1,21 @@
+package utils
+
+import (
+	"net/url"
+)
+
+// InjectQuery appends the encoded query parameters to raw.
+func InjectQuery(raw string, query url.Values) (string, error) {
+	param := query.Encode()
+	if param == "" {
+		return raw, nil
+	}
+	u, err := url.Parse(raw)
+	if err != nil {
+		return "", err
+	}
+	joiner := "?"
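+	// e.g. raw "https://h/p" with {sign: x} becomes "https://h/p?sign=x",
+	// while raw "https://h/p?a=1" becomes "https://h/p?a=1&sign=x".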
+ if u.RawQuery != "" { + joiner = "&" + } + return raw + joiner + param, nil +} diff --git a/public/dist/README.md b/public/dist/README.md new file mode 100644 index 0000000000000000000000000000000000000000..d8709fb57107cc21c0b7dcefeba488f0b2d58154 --- /dev/null +++ b/public/dist/README.md @@ -0,0 +1 @@ +## Put dist of frontend here. \ No newline at end of file diff --git a/public/public.go b/public/public.go new file mode 100644 index 0000000000000000000000000000000000000000..e94146c3b0878744c5d1122378eab6f599eabb8b --- /dev/null +++ b/public/public.go @@ -0,0 +1,6 @@ +package public + +import "embed" + +//go:embed all:dist +var Public embed.FS diff --git a/server/common/auth.go b/server/common/auth.go new file mode 100644 index 0000000000000000000000000000000000000000..b6a79b752aa664ce2b02202a0ee7a08496e6bc87 --- /dev/null +++ b/server/common/auth.go @@ -0,0 +1,55 @@ +package common + +import ( + "time" + + "github.com/alist-org/alist/v3/internal/conf" + "github.com/alist-org/alist/v3/internal/model" + "github.com/golang-jwt/jwt/v4" + "github.com/pkg/errors" +) + +var SecretKey []byte + +type UserClaims struct { + Username string `json:"username"` + PwdTS int64 `json:"pwd_ts"` + jwt.RegisteredClaims +} + +func GenerateToken(user *model.User) (tokenString string, err error) { + claim := UserClaims{ + Username: user.Username, + PwdTS: user.PwdTS, + RegisteredClaims: jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Duration(conf.Conf.TokenExpiresIn) * time.Hour)), + IssuedAt: jwt.NewNumericDate(time.Now()), + NotBefore: jwt.NewNumericDate(time.Now()), + }} + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claim) + tokenString, err = token.SignedString(SecretKey) + return tokenString, err +} + +func ParseToken(tokenString string) (*UserClaims, error) { + token, err := jwt.ParseWithClaims(tokenString, &UserClaims{}, func(token *jwt.Token) (interface{}, error) { + return SecretKey, nil + }) + if err != nil { + if ve, ok := err.(*jwt.ValidationError); ok { + if ve.Errors&jwt.ValidationErrorMalformed != 0 { + return nil, errors.New("that's not even a token") + } else if ve.Errors&jwt.ValidationErrorExpired != 0 { + return nil, errors.New("token is expired") + } else if ve.Errors&jwt.ValidationErrorNotValidYet != 0 { + return nil, errors.New("token not active yet") + } else { + return nil, errors.New("couldn't handle this token") + } + } + } + if claims, ok := token.Claims.(*UserClaims); ok && token.Valid { + return claims, nil + } + return nil, errors.New("couldn't handle this token") +} diff --git a/server/common/base.go b/server/common/base.go new file mode 100644 index 0000000000000000000000000000000000000000..eb6ef2b8ac2f790ad3f464301f263974ed1f49d8 --- /dev/null +++ b/server/common/base.go @@ -0,0 +1,30 @@ +package common + +import ( + "fmt" + "net/http" + stdpath "path" + "strings" + + "github.com/alist-org/alist/v3/internal/conf" +) + +func GetApiUrl(r *http.Request) string { + api := conf.Conf.SiteURL + if strings.HasPrefix(api, "http") { + return api + } + if r != nil { + protocol := "http" + if r.TLS != nil || r.Header.Get("X-Forwarded-Proto") == "https" { + protocol = "https" + } + host := r.Host + if r.Header.Get("X-Forwarded-Host") != "" { + host = r.Header.Get("X-Forwarded-Host") + } + api = fmt.Sprintf("%s://%s", protocol, stdpath.Join(host, api)) + } + api = strings.TrimSuffix(api, "/") + return api +} diff --git a/server/common/check.go b/server/common/check.go new file mode 100644 index 
0000000000000000000000000000000000000000..78051f4ee1e58354223661ab5b5d6532143c1816 --- /dev/null +++ b/server/common/check.go @@ -0,0 +1,74 @@ +package common + +import ( + "path" + "strings" + + "github.com/alist-org/alist/v3/internal/conf" + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/internal/op" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/dlclark/regexp2" +) + +func IsStorageSignEnabled(rawPath string) bool { + storage := op.GetBalancedStorage(rawPath) + return storage != nil && storage.GetStorage().EnableSign +} + +func CanWrite(meta *model.Meta, path string) bool { + if meta == nil || !meta.Write { + return false + } + return meta.WSub || meta.Path == path +} + +func IsApply(metaPath, reqPath string, applySub bool) bool { + if utils.PathEqual(metaPath, reqPath) { + return true + } + return utils.IsSubPath(metaPath, reqPath) && applySub +} + +func CanAccess(user *model.User, meta *model.Meta, reqPath string, password string) bool { + // if the reqPath is in hide (only can check the nearest meta) and user can't see hides, can't access + if meta != nil && !user.CanSeeHides() && meta.Hide != "" && + IsApply(meta.Path, path.Dir(reqPath), meta.HSub) { // the meta should apply to the parent of current path + for _, hide := range strings.Split(meta.Hide, "\n") { + re := regexp2.MustCompile(hide, regexp2.None) + if isMatch, _ := re.MatchString(path.Base(reqPath)); isMatch { + return false + } + } + } + // if is not guest and can access without password + if user.CanAccessWithoutPassword() { + return true + } + // if meta is nil or password is empty, can access + if meta == nil || meta.Password == "" { + return true + } + // if meta doesn't apply to sub_folder, can access + if !utils.PathEqual(meta.Path, reqPath) && !meta.PSub { + return true + } + // validate password + return meta.Password == password +} + +// ShouldProxy TODO need optimize +// when should be proxy? +// 1. config.MustProxy() +// 2. storage.WebProxy +// 3. 
proxy_types +func ShouldProxy(storage driver.Driver, filename string) bool { + if storage.Config().MustProxy() || storage.GetStorage().WebProxy { + return true + } + if utils.SliceContains(conf.SlicesMap[conf.ProxyTypes], utils.Ext(filename)) { + return true + } + return false +} diff --git a/server/common/check_test.go b/server/common/check_test.go new file mode 100644 index 0000000000000000000000000000000000000000..33114603be2024bf20013756efc0798e4b73352b --- /dev/null +++ b/server/common/check_test.go @@ -0,0 +1,24 @@ +package common + +import "testing" + +func TestIsApply(t *testing.T) { + datas := []struct { + metaPath string + reqPath string + applySub bool + result bool + }{ + { + metaPath: "/", + reqPath: "/test", + applySub: true, + result: true, + }, + } + for i, data := range datas { + if IsApply(data.metaPath, data.reqPath, data.applySub) != data.result { + t.Errorf("TestIsApply %d failed", i) + } + } +} diff --git a/server/common/common.go b/server/common/common.go new file mode 100644 index 0000000000000000000000000000000000000000..e231ffe6e88c71b836f295a79e9de94e3930e859 --- /dev/null +++ b/server/common/common.go @@ -0,0 +1,91 @@ +package common + +import ( + "context" + "net/http" + "strings" + + "github.com/alist-org/alist/v3/cmd/flags" + "github.com/alist-org/alist/v3/internal/conf" + "github.com/gin-gonic/gin" + log "github.com/sirupsen/logrus" +) + +func hidePrivacy(msg string) string { + for _, r := range conf.PrivacyReg { + msg = r.ReplaceAllStringFunc(msg, func(s string) string { + return strings.Repeat("*", len(s)) + }) + } + return msg +} + +// ErrorResp is used to return error response +// @param l: if true, log error +func ErrorResp(c *gin.Context, err error, code int, l ...bool) { + ErrorWithDataResp(c, err, code, nil, l...) 
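+	// ErrorResp now delegates to ErrorWithDataResp with nil data; the
+	// commented-out block below is the previous inline implementation.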
+ //if len(l) > 0 && l[0] { + // if flags.Debug || flags.Dev { + // log.Errorf("%+v", err) + // } else { + // log.Errorf("%v", err) + // } + //} + //c.JSON(200, Resp[interface{}]{ + // Code: code, + // Message: hidePrivacy(err.Error()), + // Data: nil, + //}) + //c.Abort() +} + +func ErrorWithDataResp(c *gin.Context, err error, code int, data interface{}, l ...bool) { + if len(l) > 0 && l[0] { + if flags.Debug || flags.Dev { + log.Errorf("%+v", err) + } else { + log.Errorf("%v", err) + } + } + c.JSON(200, Resp[interface{}]{ + Code: code, + Message: hidePrivacy(err.Error()), + Data: data, + }) + c.Abort() +} + +func ErrorStrResp(c *gin.Context, str string, code int, l ...bool) { + if len(l) != 0 && l[0] { + log.Error(str) + } + c.JSON(200, Resp[interface{}]{ + Code: code, + Message: hidePrivacy(str), + Data: nil, + }) + c.Abort() +} + +func SuccessResp(c *gin.Context, data ...interface{}) { + if len(data) == 0 { + c.JSON(200, Resp[interface{}]{ + Code: 200, + Message: "success", + Data: nil, + }) + return + } + c.JSON(200, Resp[interface{}]{ + Code: 200, + Message: "success", + Data: data[0], + }) +} + +func GetHttpReq(ctx context.Context) *http.Request { + if c, ok := ctx.(*gin.Context); ok { + return c.Request + } + return nil +} diff --git a/server/common/hide_privacy_test.go b/server/common/hide_privacy_test.go new file mode 100644 index 0000000000000000000000000000000000000000..8cb0c63b9c0fd9cbb9414f0ef98bee5d99d96c52 --- /dev/null +++ b/server/common/hide_privacy_test.go @@ -0,0 +1,18 @@ +package common + +import ( + "regexp" + "testing" + + "github.com/alist-org/alist/v3/internal/conf" +) + +func TestHidePrivacy(t *testing.T) { + reg, err := regexp.Compile("(?U)access_token=(.*)&") + if err != nil { + t.Fatal(err) + } + conf.PrivacyReg = []*regexp.Regexp{reg} + res := hidePrivacy(`Get "https://pan.baidu.com/rest/2.0/xpan/file?access_token=121.d1f66e95acfa40274920079396a51c48.Y2aP2vQDq90hLBE3PAbVije59uTcn7GiWUfw8LCM_olw&dir=%2F&limit=200&method=list&order=name&start=0&web=web " : net/http: TLS handshake timeout`) + t.Log(res) +} diff --git a/server/common/proxy.go b/server/common/proxy.go new file mode 100644 index 0000000000000000000000000000000000000000..10923613edeb52142b46826ab4729f0c973be518 --- /dev/null +++ b/server/common/proxy.go @@ -0,0 +1,104 @@ +package common + +import ( + "context" + "fmt" + "io" + "net/http" + "net/url" + + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/internal/net" + "github.com/alist-org/alist/v3/internal/stream" + "github.com/alist-org/alist/v3/pkg/http_range" + "github.com/alist-org/alist/v3/pkg/utils" + log "github.com/sirupsen/logrus" +) + +func Proxy(w http.ResponseWriter, r *http.Request, link *model.Link, file model.Obj) error { + if link.MFile != nil { + defer link.MFile.Close() + attachFileName(w, file) + contentType := link.Header.Get("Content-Type") + if contentType != "" { + w.Header().Set("Content-Type", contentType) + } + http.ServeContent(w, r, file.GetName(), file.ModTime(), link.MFile) + return nil + } else if link.RangeReadCloser != nil { + attachFileName(w, file) + net.ServeHTTP(w, r, file.GetName(), file.ModTime(), file.GetSize(), link.RangeReadCloser.RangeRead) + defer func() { + _ = link.RangeReadCloser.Close() + }() + return nil + } else if link.Concurrency != 0 || link.PartSize != 0 { + attachFileName(w, file) + size := file.GetSize() + //var finalClosers model.Closers + finalClosers := utils.EmptyClosers() + header := net.ProcessHeader(r.Header, link.Header) + rangeReader := func(ctx 
context.Context, httpRange http_range.Range) (io.ReadCloser, error) { + down := net.NewDownloader(func(d *net.Downloader) { + d.Concurrency = link.Concurrency + d.PartSize = link.PartSize + }) + req := &net.HttpRequestParams{ + URL: link.URL, + Range: httpRange, + Size: size, + HeaderRef: header, + } + rc, err := down.Download(ctx, req) + finalClosers.Add(rc) + return rc, err + } + net.ServeHTTP(w, r, file.GetName(), file.ModTime(), file.GetSize(), rangeReader) + defer finalClosers.Close() + return nil + } else { + //transparent proxy + header := net.ProcessHeader(r.Header, link.Header) + res, err := net.RequestHttp(context.Background(), r.Method, header, link.URL) + if err != nil { + return err + } + defer res.Body.Close() + + for h, v := range res.Header { + w.Header()[h] = v + } + w.WriteHeader(res.StatusCode) + if r.Method == http.MethodHead { + return nil + } + _, err = io.Copy(w, res.Body) + if err != nil { + return err + } + return nil + } +} +func attachFileName(w http.ResponseWriter, file model.Obj) { + fileName := file.GetName() + w.Header().Set("Content-Disposition", fmt.Sprintf(`attachment; filename="%s"; filename*=UTF-8''%s`, fileName, url.PathEscape(fileName))) + w.Header().Set("Content-Type", utils.GetMimeType(fileName)) +} + +var NoProxyRange = &model.RangeReadCloser{} + +func ProxyRange(link *model.Link, size int64) { + if link.MFile != nil { + return + } + if link.RangeReadCloser == nil { + var rrc, err = stream.GetRangeReadCloserFromLink(size, link) + if err != nil { + log.Warnf("ProxyRange error: %s", err) + return + } + link.RangeReadCloser = rrc + } else if link.RangeReadCloser == NoProxyRange { + link.RangeReadCloser = nil + } +} diff --git a/server/common/resp.go b/server/common/resp.go new file mode 100644 index 0000000000000000000000000000000000000000..53308339aa699b4d19fb7fb8f67d60f79c1724b9 --- /dev/null +++ b/server/common/resp.go @@ -0,0 +1,12 @@ +package common + +type Resp[T any] struct { + Code int `json:"code"` + Message string `json:"message"` + Data T `json:"data"` +} + +type PageResp struct { + Content interface{} `json:"content"` + Total int64 `json:"total"` +} diff --git a/server/common/sign.go b/server/common/sign.go new file mode 100644 index 0000000000000000000000000000000000000000..1d56ad208ac500a272f0f241c1965885f333ebcf --- /dev/null +++ b/server/common/sign.go @@ -0,0 +1,17 @@ +package common + +import ( + stdpath "path" + + "github.com/alist-org/alist/v3/internal/conf" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/internal/setting" + "github.com/alist-org/alist/v3/internal/sign" +) + +func Sign(obj model.Obj, parent string, encrypt bool) string { + if obj.IsDir() || (!encrypt && !setting.GetBool(conf.SignAll)) { + return "" + } + return sign.Sign(stdpath.Join(parent, obj.GetName())) +} diff --git a/server/debug.go b/server/debug.go new file mode 100644 index 0000000000000000000000000000000000000000..081ef8c33815dd81c4b72ad31d55697dfb57a9ee --- /dev/null +++ b/server/debug.go @@ -0,0 +1,32 @@ +package server + +import ( + "net/http" + _ "net/http/pprof" + "runtime" + + "github.com/alist-org/alist/v3/server/common" + "github.com/alist-org/alist/v3/server/middlewares" + "github.com/gin-gonic/gin" +) + +func _pprof(g *gin.RouterGroup) { + g.Any("/*name", gin.WrapH(http.DefaultServeMux)) +} + +func debug(g *gin.RouterGroup) { + g.GET("/path/*path", middlewares.Down, func(ctx *gin.Context) { + rawPath := ctx.MustGet("path").(string) + ctx.JSON(200, gin.H{ + "path": rawPath, + }) + }) + g.GET("/hide_privacy", 
func(ctx *gin.Context) { + common.ErrorStrResp(ctx, "This is ip: 1.1.1.1", 400) + }) + g.GET("/gc", func(c *gin.Context) { + runtime.GC() + c.String(http.StatusOK, "ok") + }) + _pprof(g.Group("/pprof")) +} diff --git a/server/handles/auth.go b/server/handles/auth.go new file mode 100644 index 0000000000000000000000000000000000000000..209bdd3a2b80d9621b2680c17394e0a250326f68 --- /dev/null +++ b/server/handles/auth.go @@ -0,0 +1,183 @@ +package handles + +import ( + "bytes" + "encoding/base64" + "image/png" + "time" + + "github.com/Xhofe/go-cache" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/internal/op" + "github.com/alist-org/alist/v3/server/common" + "github.com/gin-gonic/gin" + "github.com/pquerna/otp/totp" +) + +var loginCache = cache.NewMemCache[int]() +var ( + defaultDuration = time.Minute * 5 + defaultTimes = 5 +) + +type LoginReq struct { + Username string `json:"username" binding:"required"` + Password string `json:"password"` + OtpCode string `json:"otp_code"` +} + +// Login Deprecated +func Login(c *gin.Context) { + var req LoginReq + if err := c.ShouldBind(&req); err != nil { + common.ErrorResp(c, err, 400) + return + } + req.Password = model.StaticHash(req.Password) + loginHash(c, &req) +} + +// LoginHash login with password hashed by sha256 +func LoginHash(c *gin.Context) { + var req LoginReq + if err := c.ShouldBind(&req); err != nil { + common.ErrorResp(c, err, 400) + return + } + loginHash(c, &req) +} + +func loginHash(c *gin.Context, req *LoginReq) { + // check count of login + ip := c.ClientIP() + count, ok := loginCache.Get(ip) + if ok && count >= defaultTimes { + common.ErrorStrResp(c, "Too many unsuccessful sign-in attempts have been made using an incorrect username or password, Try again later.", 429) + loginCache.Expire(ip, defaultDuration) + return + } + // check username + user, err := op.GetUserByName(req.Username) + if err != nil { + common.ErrorResp(c, err, 400) + loginCache.Set(ip, count+1) + return + } + // validate password hash + if err := user.ValidatePwdStaticHash(req.Password); err != nil { + common.ErrorResp(c, err, 400) + loginCache.Set(ip, count+1) + return + } + // check 2FA + if user.OtpSecret != "" { + if !totp.Validate(req.OtpCode, user.OtpSecret) { + common.ErrorStrResp(c, "Invalid 2FA code", 402) + loginCache.Set(ip, count+1) + return + } + } + // generate token + token, err := common.GenerateToken(user) + if err != nil { + common.ErrorResp(c, err, 400, true) + return + } + common.SuccessResp(c, gin.H{"token": token}) + loginCache.Del(ip) +} + +type UserResp struct { + model.User + Otp bool `json:"otp"` +} + +// CurrentUser get current user by token +// if token is empty, return guest user +func CurrentUser(c *gin.Context) { + user := c.MustGet("user").(*model.User) + userResp := UserResp{ + User: *user, + } + userResp.Password = "" + if userResp.OtpSecret != "" { + userResp.Otp = true + } + common.SuccessResp(c, userResp) +} + +func UpdateCurrent(c *gin.Context) { + var req model.User + if err := c.ShouldBind(&req); err != nil { + common.ErrorResp(c, err, 400) + return + } + user := c.MustGet("user").(*model.User) + user.Username = req.Username + if req.Password != "" { + user.SetPassword(req.Password) + } + user.SsoID = req.SsoID + if err := op.UpdateUser(user); err != nil { + common.ErrorResp(c, err, 500) + } else { + common.SuccessResp(c) + } +} + +func Generate2FA(c *gin.Context) { + user := c.MustGet("user").(*model.User) + if user.IsGuest() { + common.ErrorStrResp(c, "Guest user can not generate 2FA 
code", 403) + return + } + key, err := totp.Generate(totp.GenerateOpts{ + Issuer: "Alist", + AccountName: user.Username, + }) + if err != nil { + common.ErrorResp(c, err, 500) + return + } + img, err := key.Image(400, 400) + if err != nil { + common.ErrorResp(c, err, 500) + return + } + // to base64 + var buf bytes.Buffer + png.Encode(&buf, img) + b64 := base64.StdEncoding.EncodeToString(buf.Bytes()) + common.SuccessResp(c, gin.H{ + "qr": "data:image/png;base64," + b64, + "secret": key.Secret(), + }) +} + +type Verify2FAReq struct { + Code string `json:"code" binding:"required"` + Secret string `json:"secret" binding:"required"` +} + +func Verify2FA(c *gin.Context) { + var req Verify2FAReq + if err := c.ShouldBind(&req); err != nil { + common.ErrorResp(c, err, 400) + return + } + user := c.MustGet("user").(*model.User) + if user.IsGuest() { + common.ErrorStrResp(c, "Guest user can not generate 2FA code", 403) + return + } + if !totp.Validate(req.Code, req.Secret) { + common.ErrorStrResp(c, "Invalid 2FA code", 400) + return + } + user.OtpSecret = req.Secret + if err := op.UpdateUser(user); err != nil { + common.ErrorResp(c, err, 500) + } else { + common.SuccessResp(c) + } +} diff --git a/server/handles/down.go b/server/handles/down.go new file mode 100644 index 0000000000000000000000000000000000000000..d3d41e85a2beaa645616198643a86ee148668a66 --- /dev/null +++ b/server/handles/down.go @@ -0,0 +1,138 @@ +package handles + +import ( + "fmt" + "io" + stdpath "path" + "strings" + + "github.com/alist-org/alist/v3/internal/conf" + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/fs" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/internal/setting" + "github.com/alist-org/alist/v3/internal/sign" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/alist-org/alist/v3/server/common" + "github.com/gin-gonic/gin" + log "github.com/sirupsen/logrus" +) + +func Down(c *gin.Context) { + rawPath := c.MustGet("path").(string) + filename := stdpath.Base(rawPath) + storage, err := fs.GetStorage(rawPath, &fs.GetStoragesArgs{}) + if err != nil { + common.ErrorResp(c, err, 500) + return + } + if common.ShouldProxy(storage, filename) { + Proxy(c) + return + } else { + link, _, err := fs.Link(c, rawPath, model.LinkArgs{ + IP: c.ClientIP(), + Header: c.Request.Header, + Type: c.Query("type"), + HttpReq: c.Request, + }) + if err != nil { + common.ErrorResp(c, err, 500) + return + } + if link.MFile != nil { + defer func(ReadSeekCloser io.ReadCloser) { + err := ReadSeekCloser.Close() + if err != nil { + log.Errorf("close data error: %s", err) + } + }(link.MFile) + } + c.Header("Referrer-Policy", "no-referrer") + c.Header("Cache-Control", "max-age=0, no-cache, no-store, must-revalidate") + if setting.GetBool(conf.ForwardDirectLinkParams) { + query := c.Request.URL.Query() + for _, v := range conf.SlicesMap[conf.IgnoreDirectLinkParams] { + query.Del(v) + } + link.URL, err = utils.InjectQuery(link.URL, query) + if err != nil { + common.ErrorResp(c, err, 500) + return + } + } + c.Redirect(302, link.URL) + } +} + +func Proxy(c *gin.Context) { + rawPath := c.MustGet("path").(string) + filename := stdpath.Base(rawPath) + storage, err := fs.GetStorage(rawPath, &fs.GetStoragesArgs{}) + if err != nil { + common.ErrorResp(c, err, 500) + return + } + if canProxy(storage, filename) { + downProxyUrl := storage.GetStorage().DownProxyUrl + if downProxyUrl != "" { + _, ok := c.GetQuery("d") + if !ok { + URL := fmt.Sprintf("%s%s?sign=%s", + 
strings.Split(downProxyUrl, "\n")[0], + utils.EncodePath(rawPath, true), + sign.Sign(rawPath)) + c.Redirect(302, URL) + return + } + } + link, file, err := fs.Link(c, rawPath, model.LinkArgs{ + Header: c.Request.Header, + Type: c.Query("type"), + HttpReq: c.Request, + }) + if err != nil { + common.ErrorResp(c, err, 500) + return + } + if link.URL != "" && setting.GetBool(conf.ForwardDirectLinkParams) { + query := c.Request.URL.Query() + for _, v := range conf.SlicesMap[conf.IgnoreDirectLinkParams] { + query.Del(v) + } + link.URL, err = utils.InjectQuery(link.URL, query) + if err != nil { + common.ErrorResp(c, err, 500) + return + } + } + err = common.Proxy(c.Writer, c.Request, link, file) + if err != nil { + common.ErrorResp(c, err, 500, true) + return + } + } else { + common.ErrorStrResp(c, "proxy not allowed", 403) + return + } +} + +// TODO need optimize +// when can be proxy? +// 1. text file +// 2. config.MustProxy() +// 3. storage.WebProxy +// 4. proxy_types +// solution: text_file + shouldProxy() +func canProxy(storage driver.Driver, filename string) bool { + if storage.Config().MustProxy() || storage.GetStorage().WebProxy || storage.GetStorage().WebdavProxy() { + return true + } + if utils.SliceContains(conf.SlicesMap[conf.ProxyTypes], utils.Ext(filename)) { + return true + } + if utils.SliceContains(conf.SlicesMap[conf.TextTypes], utils.Ext(filename)) { + return true + } + return false +} diff --git a/server/handles/driver.go b/server/handles/driver.go new file mode 100644 index 0000000000000000000000000000000000000000..10c27ce8b61b6b05e81e9398c4f8ddf7ee909ca9 --- /dev/null +++ b/server/handles/driver.go @@ -0,0 +1,28 @@ +package handles + +import ( + "fmt" + + "github.com/alist-org/alist/v3/internal/op" + "github.com/alist-org/alist/v3/server/common" + "github.com/gin-gonic/gin" +) + +func ListDriverInfo(c *gin.Context) { + common.SuccessResp(c, op.GetDriverInfoMap()) +} + +func ListDriverNames(c *gin.Context) { + common.SuccessResp(c, op.GetDriverNames()) +} + +func GetDriverInfo(c *gin.Context) { + driverName := c.Query("driver") + infoMap := op.GetDriverInfoMap() + items, ok := infoMap[driverName] + if !ok { + common.ErrorStrResp(c, fmt.Sprintf("driver [%s] not found", driverName), 404) + return + } + common.SuccessResp(c, items) +} diff --git a/server/handles/fsbatch.go b/server/handles/fsbatch.go new file mode 100644 index 0000000000000000000000000000000000000000..fa7971dfbe1ed2640f82b48872bf88d0d0825cc6 --- /dev/null +++ b/server/handles/fsbatch.go @@ -0,0 +1,211 @@ +package handles + +import ( + "fmt" + "regexp" + + "github.com/alist-org/alist/v3/internal/errs" + "github.com/alist-org/alist/v3/internal/fs" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/internal/op" + "github.com/alist-org/alist/v3/pkg/generic" + "github.com/alist-org/alist/v3/server/common" + "github.com/gin-gonic/gin" + "github.com/pkg/errors" +) + +type BatchRenameReq struct { + SrcDir string `json:"src_dir"` + RenameObjects []struct { + SrcName string `json:"src_name"` + NewName string `json:"new_name"` + } `json:"rename_objects"` +} + +func FsBatchRename(c *gin.Context) { + var req BatchRenameReq + if err := c.ShouldBind(&req); err != nil { + common.ErrorResp(c, err, 400) + return + } + user := c.MustGet("user").(*model.User) + if !user.CanRename() { + common.ErrorResp(c, errs.PermissionDenied, 403) + return + } + + reqPath, err := user.JoinPath(req.SrcDir) + if err != nil { + common.ErrorResp(c, err, 403) + return + } + + meta, err := op.GetNearestMeta(reqPath) + if 
err != nil { + if !errors.Is(errors.Cause(err), errs.MetaNotFound) { + common.ErrorResp(c, err, 500, true) + return + } + } + c.Set("meta", meta) + for _, renameObject := range req.RenameObjects { + if renameObject.SrcName == "" || renameObject.NewName == "" { + continue + } + filePath := fmt.Sprintf("%s/%s", reqPath, renameObject.SrcName) + if err := fs.Rename(c, filePath, renameObject.NewName); err != nil { + common.ErrorResp(c, err, 500) + return + } + } + common.SuccessResp(c) +} + +type RecursiveMoveReq struct { + SrcDir string `json:"src_dir"` + DstDir string `json:"dst_dir"` +} + +func FsRecursiveMove(c *gin.Context) { + var req RecursiveMoveReq + if err := c.ShouldBind(&req); err != nil { + common.ErrorResp(c, err, 400) + return + } + + user := c.MustGet("user").(*model.User) + if !user.CanMove() { + common.ErrorResp(c, errs.PermissionDenied, 403) + return + } + srcDir, err := user.JoinPath(req.SrcDir) + if err != nil { + common.ErrorResp(c, err, 403) + return + } + dstDir, err := user.JoinPath(req.DstDir) + if err != nil { + common.ErrorResp(c, err, 403) + return + } + + meta, err := op.GetNearestMeta(srcDir) + if err != nil { + if !errors.Is(errors.Cause(err), errs.MetaNotFound) { + common.ErrorResp(c, err, 500, true) + return + } + } + c.Set("meta", meta) + + rootFiles, err := fs.List(c, srcDir, &fs.ListArgs{}) + if err != nil { + common.ErrorResp(c, err, 500) + return + } + + // record the file path + filePathMap := make(map[model.Obj]string) + movingFiles := generic.NewQueue[model.Obj]() + for _, file := range rootFiles { + movingFiles.Push(file) + filePathMap[file] = srcDir + } + + for !movingFiles.IsEmpty() { + + movingFile := movingFiles.Pop() + movingFilePath := filePathMap[movingFile] + movingFileName := fmt.Sprintf("%s/%s", movingFilePath, movingFile.GetName()) + if movingFile.IsDir() { + // directory, recursive move + subFilePath := movingFileName + subFiles, err := fs.List(c, movingFileName, &fs.ListArgs{Refresh: true}) + if err != nil { + common.ErrorResp(c, err, 500) + return + } + for _, subFile := range subFiles { + movingFiles.Push(subFile) + filePathMap[subFile] = subFilePath + } + } else { + + if movingFilePath == dstDir { + // same directory, don't move + continue + } + + // move + err := fs.Move(c, movingFileName, dstDir, movingFiles.IsEmpty()) + if err != nil { + common.ErrorResp(c, err, 500) + return + } + } + + } + + common.SuccessResp(c) +} + +type RegexRenameReq struct { + SrcDir string `json:"src_dir"` + SrcNameRegex string `json:"src_name_regex"` + NewNameRegex string `json:"new_name_regex"` +} + +func FsRegexRename(c *gin.Context) { + var req RegexRenameReq + if err := c.ShouldBind(&req); err != nil { + common.ErrorResp(c, err, 400) + return + } + user := c.MustGet("user").(*model.User) + if !user.CanRename() { + common.ErrorResp(c, errs.PermissionDenied, 403) + return + } + + reqPath, err := user.JoinPath(req.SrcDir) + if err != nil { + common.ErrorResp(c, err, 403) + return + } + + meta, err := op.GetNearestMeta(reqPath) + if err != nil { + if !errors.Is(errors.Cause(err), errs.MetaNotFound) { + common.ErrorResp(c, err, 500, true) + return + } + } + c.Set("meta", meta) + + srcRegexp, err := regexp.Compile(req.SrcNameRegex) + if err != nil { + common.ErrorResp(c, err, 500) + return + } + + files, err := fs.List(c, reqPath, &fs.ListArgs{}) + if err != nil { + common.ErrorResp(c, err, 500) + return + } + + for _, file := range files { + + if srcRegexp.MatchString(file.GetName()) { + filePath := fmt.Sprintf("%s/%s", reqPath, file.GetName()) + newFileName 
:= srcRegexp.ReplaceAllString(file.GetName(), req.NewNameRegex) + if err := fs.Rename(c, filePath, newFileName); err != nil { + common.ErrorResp(c, err, 500) + return + } + } + + } + + common.SuccessResp(c) +} diff --git a/server/handles/fsmanage.go b/server/handles/fsmanage.go new file mode 100644 index 0000000000000000000000000000000000000000..c9be7f2d52ed44aec0e54b6225162a18488715a9 --- /dev/null +++ b/server/handles/fsmanage.go @@ -0,0 +1,398 @@ +package handles + +import ( + "fmt" + "io" + stdpath "path" + + "github.com/alist-org/alist/v3/pkg/tache" + + "github.com/alist-org/alist/v3/internal/errs" + "github.com/alist-org/alist/v3/internal/fs" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/internal/op" + "github.com/alist-org/alist/v3/internal/sign" + "github.com/alist-org/alist/v3/pkg/generic" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/alist-org/alist/v3/server/common" + "github.com/gin-gonic/gin" + "github.com/pkg/errors" + log "github.com/sirupsen/logrus" +) + +type MkdirOrLinkReq struct { + Path string `json:"path" form:"path"` +} + +func FsMkdir(c *gin.Context) { + var req MkdirOrLinkReq + if err := c.ShouldBind(&req); err != nil { + common.ErrorResp(c, err, 400) + return + } + user := c.MustGet("user").(*model.User) + reqPath, err := user.JoinPath(req.Path) + if err != nil { + common.ErrorResp(c, err, 403) + return + } + if !user.CanWrite() { + meta, err := op.GetNearestMeta(stdpath.Dir(reqPath)) + if err != nil { + if !errors.Is(errors.Cause(err), errs.MetaNotFound) { + common.ErrorResp(c, err, 500, true) + return + } + } + if !common.CanWrite(meta, reqPath) { + common.ErrorResp(c, errs.PermissionDenied, 403) + return + } + } + if err := fs.MakeDir(c, reqPath); err != nil { + common.ErrorResp(c, err, 500) + return + } + common.SuccessResp(c) +} + +type MoveCopyReq struct { + SrcDir string `json:"src_dir"` + DstDir string `json:"dst_dir"` + Override bool `json:"override"` + Names []string `json:"names"` +} + +func FsMove(c *gin.Context) { + var req MoveCopyReq + if err := c.ShouldBind(&req); err != nil { + common.ErrorResp(c, err, 400) + return + } + if len(req.Names) == 0 { + common.ErrorStrResp(c, "Empty file names", 400) + return + } + user := c.MustGet("user").(*model.User) + if !user.CanMove() { + common.ErrorResp(c, errs.PermissionDenied, 403) + return + } + srcDir, err := user.JoinPath(req.SrcDir) + if err != nil { + common.ErrorResp(c, err, 403) + return + } + dstDir, err := user.JoinPath(req.DstDir) + if err != nil { + common.ErrorResp(c, err, 403) + return + } + for i, name := range req.Names { + err := fs.Move(c, stdpath.Join(srcDir, name), dstDir, len(req.Names) > i+1) + if err != nil { + common.ErrorResp(c, err, 500) + return + } + } + common.SuccessResp(c) +} + +func FsCopy(c *gin.Context) { + var req MoveCopyReq + if err := c.ShouldBind(&req); err != nil { + common.ErrorResp(c, err, 400) + return + } + if len(req.Names) == 0 { + common.ErrorStrResp(c, "Empty file names", 400) + return + } + user := c.MustGet("user").(*model.User) + if !user.CanCopy() { + common.ErrorResp(c, errs.PermissionDenied, 403) + return + } + srcDir, err := user.JoinPath(req.SrcDir) + if err != nil { + common.ErrorResp(c, err, 403) + return + } + dstDir, err := user.JoinPath(req.DstDir) + if err != nil { + common.ErrorResp(c, err, 403) + return + } + var addedTasks []tache.TaskWithInfo + for i, name := range req.Names { + t, err := fs.Copy(c, stdpath.Join(srcDir, name), dstDir, req.Override, len(req.Names) > i+1) + if t != nil { + addedTasks 
= append(addedTasks, t)
+		}
+		if err != nil {
+			common.ErrorResp(c, err, 500)
+			return
+		}
+	}
+	common.SuccessResp(c, gin.H{
+		"tasks": getTaskInfos(addedTasks),
+	})
+}
+
+// CopyItem describes one copy operation. Compared with FsCopy, FsCopyItem is
+// more flexible because each file can specify its own destination directory.
+// SrcFile is the absolute path of the source file;
+// DstDir is the absolute path of the directory to copy it into.
+type CopyItem struct {
+	SrcFile string `json:"src_file"`
+	DstDir  string `json:"dst_dir"`
+}
+
+type CopyItemReq struct {
+	Override bool       `json:"override"`
+	Names    []CopyItem `json:"names"`
+}
+
+func FsCopyItem(c *gin.Context) {
+	var req CopyItemReq
+	if err := c.ShouldBind(&req); err != nil {
+		common.ErrorResp(c, err, 400)
+		return
+	}
+	if len(req.Names) == 0 {
+		common.ErrorStrResp(c, "Empty file names", 400)
+		return
+	}
+	user := c.MustGet("user").(*model.User)
+	if !user.CanCopy() {
+		common.ErrorResp(c, errs.PermissionDenied, 403)
+		return
+	}
+	// Note: unlike FsCopy, the paths are used as-is and are not joined with
+	// the user's base path:
+	// srcDir, err := user.JoinPath(req.SrcDir)
+	// if err != nil {
+	//	common.ErrorResp(c, err, 403)
+	//	return
+	// }
+	// dstDir, err := user.JoinPath(req.DstDir)
+	// if err != nil {
+	//	common.ErrorResp(c, err, 403)
+	//	return
+	// }
+	var addedTasks []tache.TaskWithInfo
+	for i, item := range req.Names {
+		t, err := fs.Copy(c, item.SrcFile, item.DstDir, req.Override, len(req.Names) > i+1)
+		if t != nil {
+			addedTasks = append(addedTasks, t)
+		}
+		if err != nil {
+			common.ErrorResp(c, err, 500)
+			return
+		}
+	}
+	common.SuccessResp(c, gin.H{
+		"tasks": getTaskInfos(addedTasks),
+	})
+}
+
+type RenameReq struct {
+	Path string `json:"path"`
+	Name string `json:"name"`
+}
+
+func FsRename(c *gin.Context) {
+	var req RenameReq
+	if err := c.ShouldBind(&req); err != nil {
+		common.ErrorResp(c, err, 400)
+		return
+	}
+	user := c.MustGet("user").(*model.User)
+	if !user.CanRename() {
+		common.ErrorResp(c, errs.PermissionDenied, 403)
+		return
+	}
+	reqPath, err := user.JoinPath(req.Path)
+	if err != nil {
+		common.ErrorResp(c, err, 403)
+		return
+	}
+	if err := fs.Rename(c, reqPath, req.Name); err != nil {
+		common.ErrorResp(c, err, 500)
+		return
+	}
+	common.SuccessResp(c)
+}
+
+type RemoveReq struct {
+	Dir   string   `json:"dir"`
+	Names []string `json:"names"`
+}
+
+func FsRemove(c *gin.Context) {
+	var req RemoveReq
+	if err := c.ShouldBind(&req); err != nil {
+		common.ErrorResp(c, err, 400)
+		return
+	}
+	if len(req.Names) == 0 {
+		common.ErrorStrResp(c, "Empty file names", 400)
+		return
+	}
+	user := c.MustGet("user").(*model.User)
+	if !user.CanRemove() {
+		common.ErrorResp(c, errs.PermissionDenied, 403)
+		return
+	}
+	reqDir, err := user.JoinPath(req.Dir)
+	if err != nil {
+		common.ErrorResp(c, err, 403)
+		return
+	}
+	for _, name := range req.Names {
+		err := fs.Remove(c, stdpath.Join(reqDir, name))
+		if err != nil {
+			common.ErrorResp(c, err, 500)
+			return
+		}
+	}
+	//fs.ClearCache(req.Dir)
+	common.SuccessResp(c)
+}
+
+type RemoveEmptyDirectoryReq struct {
+	SrcDir string `json:"src_dir"`
+}
+
+func FsRemoveEmptyDirectory(c *gin.Context) {
+	var req RemoveEmptyDirectoryReq
+	if err := c.ShouldBind(&req); err != nil {
+		common.ErrorResp(c, err, 400)
+		return
+	}
+
+	user := c.MustGet("user").(*model.User)
+	if !user.CanRemove() {
+		common.ErrorResp(c, errs.PermissionDenied, 403)
+		return
+	}
+	srcDir, err := user.JoinPath(req.SrcDir)
+	if err != nil {
+		common.ErrorResp(c, err, 403)
+		return
+	}
+
+	meta, err := op.GetNearestMeta(srcDir)
+	if err != nil {
+		if !errors.Is(errors.Cause(err), errs.MetaNotFound) {
+			common.ErrorResp(c, err, 500, true)
+			return
+		}
+	}
+	c.Set("meta", meta)
+
+	rootFiles, err := fs.List(c, srcDir, &fs.ListArgs{})
+	if err != nil {
common.ErrorResp(c, err, 500) + return + } + + // record the file path + filePathMap := make(map[model.Obj]string) + // record the parent file + fileParentMap := make(map[model.Obj]model.Obj) + // removing files + removingFiles := generic.NewQueue[model.Obj]() + // removed files + removedFiles := make(map[string]bool) + for _, file := range rootFiles { + if !file.IsDir() { + continue + } + removingFiles.Push(file) + filePathMap[file] = srcDir + } + + for !removingFiles.IsEmpty() { + + removingFile := removingFiles.Pop() + removingFilePath := fmt.Sprintf("%s/%s", filePathMap[removingFile], removingFile.GetName()) + + if removedFiles[removingFilePath] { + continue + } + + subFiles, err := fs.List(c, removingFilePath, &fs.ListArgs{Refresh: true}) + if err != nil { + common.ErrorResp(c, err, 500) + return + } + + if len(subFiles) == 0 { + // remove empty directory + err = fs.Remove(c, removingFilePath) + removedFiles[removingFilePath] = true + if err != nil { + common.ErrorResp(c, err, 500) + return + } + // recheck parent folder + parentFile, exist := fileParentMap[removingFile] + if exist { + removingFiles.Push(parentFile) + } + + } else { + // recursive remove + for _, subFile := range subFiles { + if !subFile.IsDir() { + continue + } + removingFiles.Push(subFile) + filePathMap[subFile] = removingFilePath + fileParentMap[subFile] = removingFile + } + } + + } + + common.SuccessResp(c) +} + +// Link return real link, just for proxy program, it may contain cookie, so just allowed for admin +func Link(c *gin.Context) { + var req MkdirOrLinkReq + if err := c.ShouldBind(&req); err != nil { + common.ErrorResp(c, err, 400) + return + } + //user := c.MustGet("user").(*model.User) + //rawPath := stdpath.Join(user.BasePath, req.Path) + // why need not join base_path? 
because it's always the full path + rawPath := req.Path + storage, err := fs.GetStorage(rawPath, &fs.GetStoragesArgs{}) + if err != nil { + common.ErrorResp(c, err, 500) + return + } + if storage.Config().OnlyLocal { + common.SuccessResp(c, model.Link{ + URL: fmt.Sprintf("%s/p%s?d&sign=%s", + common.GetApiUrl(c.Request), + utils.EncodePath(rawPath, true), + sign.Sign(rawPath)), + }) + return + } + link, _, err := fs.Link(c, rawPath, model.LinkArgs{IP: c.ClientIP(), Header: c.Request.Header, HttpReq: c.Request}) + if err != nil { + common.ErrorResp(c, err, 500) + return + } + if link.MFile != nil { + defer func(ReadSeekCloser io.ReadCloser) { + err := ReadSeekCloser.Close() + if err != nil { + log.Errorf("close link data error: %v", err) + } + }(link.MFile) + } + common.SuccessResp(c, link) + return +} diff --git a/server/handles/fsread.go b/server/handles/fsread.go new file mode 100644 index 0000000000000000000000000000000000000000..7c580f635e4c5231a40e1e14486cf714ef81caa7 --- /dev/null +++ b/server/handles/fsread.go @@ -0,0 +1,397 @@ +package handles + +import ( + "fmt" + stdpath "path" + "strings" + "time" + + "github.com/alist-org/alist/v3/internal/conf" + "github.com/alist-org/alist/v3/internal/errs" + "github.com/alist-org/alist/v3/internal/fs" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/internal/op" + "github.com/alist-org/alist/v3/internal/setting" + "github.com/alist-org/alist/v3/internal/sign" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/alist-org/alist/v3/server/common" + "github.com/gin-gonic/gin" + "github.com/pkg/errors" +) + +type ListReq struct { + model.PageReq + Path string `json:"path" form:"path"` + Password string `json:"password" form:"password"` + Refresh bool `json:"refresh"` +} + +type DirReq struct { + Path string `json:"path" form:"path"` + Password string `json:"password" form:"password"` + ForceRoot bool `json:"force_root" form:"force_root"` +} + +type ObjResp struct { + Name string `json:"name"` + Size int64 `json:"size"` + IsDir bool `json:"is_dir"` + Modified time.Time `json:"modified"` + Created time.Time `json:"created"` + Sign string `json:"sign"` + Thumb string `json:"thumb"` + Type int `json:"type"` + HashInfoStr string `json:"hashinfo"` + HashInfo map[*utils.HashType]string `json:"hash_info"` +} + +type FsListResp struct { + Content []ObjResp `json:"content"` + Total int64 `json:"total"` + Readme string `json:"readme"` + Header string `json:"header"` + Write bool `json:"write"` + Provider string `json:"provider"` +} + +func FsList(c *gin.Context) { + var req ListReq + if err := c.ShouldBind(&req); err != nil { + common.ErrorResp(c, err, 400) + return + } + req.Validate() + user := c.MustGet("user").(*model.User) + reqPath, err := user.JoinPath(req.Path) + if err != nil { + common.ErrorResp(c, err, 403) + return + } + meta, err := op.GetNearestMeta(reqPath) + if err != nil { + if !errors.Is(errors.Cause(err), errs.MetaNotFound) { + common.ErrorResp(c, err, 500, true) + return + } + } + c.Set("meta", meta) + if !common.CanAccess(user, meta, reqPath, req.Password) { + common.ErrorStrResp(c, "password is incorrect or you have no permission", 403) + return + } + if !user.CanWrite() && !common.CanWrite(meta, reqPath) && req.Refresh { + common.ErrorStrResp(c, "Refresh without permission", 403) + return + } + objs, err := fs.List(c, reqPath, &fs.ListArgs{Refresh: req.Refresh}) + if err != nil { + common.ErrorResp(c, err, 500) + return + } + total, objs := pagination(objs, &req.PageReq) + provider := "unknown" + 
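+	// provider stays "unknown" only when the storage lookup below fails;
+	// otherwise it reports the driver name of the storage backing reqPath.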
storage, err := fs.GetStorage(reqPath, &fs.GetStoragesArgs{}) + if err == nil { + provider = storage.GetStorage().Driver + } + common.SuccessResp(c, FsListResp{ + Content: toObjsResp(objs, reqPath, isEncrypt(meta, reqPath)), + Total: int64(total), + Readme: getReadme(meta, reqPath), + Header: getHeader(meta, reqPath), + Write: user.CanWrite() || common.CanWrite(meta, reqPath), + Provider: provider, + }) +} + +func FsDirs(c *gin.Context) { + var req DirReq + if err := c.ShouldBind(&req); err != nil { + common.ErrorResp(c, err, 400) + return + } + user := c.MustGet("user").(*model.User) + reqPath := req.Path + if req.ForceRoot { + if !user.IsAdmin() { + common.ErrorStrResp(c, "Permission denied", 403) + return + } + } else { + tmp, err := user.JoinPath(req.Path) + if err != nil { + common.ErrorResp(c, err, 403) + return + } + reqPath = tmp + } + meta, err := op.GetNearestMeta(reqPath) + if err != nil { + if !errors.Is(errors.Cause(err), errs.MetaNotFound) { + common.ErrorResp(c, err, 500, true) + return + } + } + c.Set("meta", meta) + if !common.CanAccess(user, meta, reqPath, req.Password) { + common.ErrorStrResp(c, "password is incorrect or you have no permission", 403) + return + } + objs, err := fs.List(c, reqPath, &fs.ListArgs{}) + if err != nil { + common.ErrorResp(c, err, 500) + return + } + dirs := filterDirs(objs) + common.SuccessResp(c, dirs) +} + +type DirResp struct { + Name string `json:"name"` + Modified time.Time `json:"modified"` +} + +func filterDirs(objs []model.Obj) []DirResp { + var dirs []DirResp + for _, obj := range objs { + if obj.IsDir() { + dirs = append(dirs, DirResp{ + Name: obj.GetName(), + Modified: obj.ModTime(), + }) + } + } + return dirs +} + +func getReadme(meta *model.Meta, path string) string { + if meta != nil && (utils.PathEqual(meta.Path, path) || meta.RSub) { + return meta.Readme + } + return "" +} + +func getHeader(meta *model.Meta, path string) string { + if meta != nil && (utils.PathEqual(meta.Path, path) || meta.HeaderSub) { + return meta.Header + } + return "" +} + +func isEncrypt(meta *model.Meta, path string) bool { + if common.IsStorageSignEnabled(path) { + return true + } + if meta == nil || meta.Password == "" { + return false + } + if !utils.PathEqual(meta.Path, path) && !meta.PSub { + return false + } + return true +} + +func pagination(objs []model.Obj, req *model.PageReq) (int, []model.Obj) { + pageIndex, pageSize := req.Page, req.PerPage + total := len(objs) + start := (pageIndex - 1) * pageSize + if start > total { + return total, []model.Obj{} + } + end := start + pageSize + if end > total { + end = total + } + return total, objs[start:end] +} + +func toObjsResp(objs []model.Obj, parent string, encrypt bool) []ObjResp { + var resp []ObjResp + for _, obj := range objs { + thumb, _ := model.GetThumb(obj) + resp = append(resp, ObjResp{ + Name: obj.GetName(), + Size: obj.GetSize(), + IsDir: obj.IsDir(), + Modified: obj.ModTime(), + Created: obj.CreateTime(), + HashInfoStr: obj.GetHash().String(), + HashInfo: obj.GetHash().Export(), + Sign: common.Sign(obj, parent, encrypt), + Thumb: thumb, + Type: utils.GetObjType(obj.GetName(), obj.IsDir()), + }) + } + return resp +} + +type FsGetReq struct { + Path string `json:"path" form:"path"` + Password string `json:"password" form:"password"` +} + +type FsGetResp struct { + ObjResp + RawURL string `json:"raw_url"` + Readme string `json:"readme"` + Header string `json:"header"` + Provider string `json:"provider"` + Related []ObjResp `json:"related"` +} + +func FsGet(c *gin.Context) { + var req 
FsGetReq + if err := c.ShouldBind(&req); err != nil { + common.ErrorResp(c, err, 400) + return + } + user := c.MustGet("user").(*model.User) + reqPath, err := user.JoinPath(req.Path) + if err != nil { + common.ErrorResp(c, err, 403) + return + } + meta, err := op.GetNearestMeta(reqPath) + if err != nil { + if !errors.Is(errors.Cause(err), errs.MetaNotFound) { + common.ErrorResp(c, err, 500) + return + } + } + c.Set("meta", meta) + if !common.CanAccess(user, meta, reqPath, req.Password) { + common.ErrorStrResp(c, "password is incorrect or you have no permission", 403) + return + } + obj, err := fs.Get(c, reqPath, &fs.GetArgs{}) + if err != nil { + common.ErrorResp(c, err, 500) + return + } + var rawURL string + + storage, err := fs.GetStorage(reqPath, &fs.GetStoragesArgs{}) + provider := "unknown" + if err == nil { + provider = storage.Config().Name + } + if !obj.IsDir() { + if err != nil { + common.ErrorResp(c, err, 500) + return + } + if storage.Config().MustProxy() || storage.GetStorage().WebProxy { + query := "" + if isEncrypt(meta, reqPath) || setting.GetBool(conf.SignAll) { + query = "?sign=" + sign.Sign(reqPath) + } + if storage.GetStorage().DownProxyUrl != "" { + rawURL = fmt.Sprintf("%s%s?sign=%s", + strings.Split(storage.GetStorage().DownProxyUrl, "\n")[0], + utils.EncodePath(reqPath, true), + sign.Sign(reqPath)) + } else { + rawURL = fmt.Sprintf("%s/p%s%s", + common.GetApiUrl(c.Request), + utils.EncodePath(reqPath, true), + query) + } + } else { + // file have raw url + if url, ok := model.GetUrl(obj); ok { + rawURL = url + } else { + // if storage is not proxy, use raw url by fs.Link + link, _, err := fs.Link(c, reqPath, model.LinkArgs{ + IP: c.ClientIP(), + Header: c.Request.Header, + HttpReq: c.Request, + }) + if err != nil { + common.ErrorResp(c, err, 500) + return + } + rawURL = link.URL + } + } + } + var related []model.Obj + parentPath := stdpath.Dir(reqPath) + sameLevelFiles, err := fs.List(c, parentPath, &fs.ListArgs{}) + if err == nil { + related = filterRelated(sameLevelFiles, obj) + } + parentMeta, _ := op.GetNearestMeta(parentPath) + thumb, _ := model.GetThumb(obj) + common.SuccessResp(c, FsGetResp{ + ObjResp: ObjResp{ + Name: obj.GetName(), + Size: obj.GetSize(), + IsDir: obj.IsDir(), + Modified: obj.ModTime(), + Created: obj.CreateTime(), + HashInfoStr: obj.GetHash().String(), + HashInfo: obj.GetHash().Export(), + Sign: common.Sign(obj, parentPath, isEncrypt(meta, reqPath)), + Type: utils.GetFileType(obj.GetName()), + Thumb: thumb, + }, + RawURL: rawURL, + Readme: getReadme(meta, reqPath), + Header: getHeader(meta, reqPath), + Provider: provider, + Related: toObjsResp(related, parentPath, isEncrypt(parentMeta, parentPath)), + }) +} + +func filterRelated(objs []model.Obj, obj model.Obj) []model.Obj { + var related []model.Obj + nameWithoutExt := strings.TrimSuffix(obj.GetName(), stdpath.Ext(obj.GetName())) + for _, o := range objs { + if o.GetName() == obj.GetName() { + continue + } + if strings.HasPrefix(o.GetName(), nameWithoutExt) { + related = append(related, o) + } + } + return related +} + +type FsOtherReq struct { + model.FsOtherArgs + Password string `json:"password" form:"password"` +} + +func FsOther(c *gin.Context) { + var req FsOtherReq + if err := c.ShouldBind(&req); err != nil { + common.ErrorResp(c, err, 400) + return + } + user := c.MustGet("user").(*model.User) + var err error + req.Path, err = user.JoinPath(req.Path) + if err != nil { + common.ErrorResp(c, err, 403) + return + } + meta, err := op.GetNearestMeta(req.Path) + if err != nil { + if 
!errors.Is(errors.Cause(err), errs.MetaNotFound) { + common.ErrorResp(c, err, 500) + return + } + } + c.Set("meta", meta) + if !common.CanAccess(user, meta, req.Path, req.Password) { + common.ErrorStrResp(c, "password is incorrect or you have no permission", 403) + return + } + res, err := fs.Other(c, req.FsOtherArgs) + if err != nil { + common.ErrorResp(c, err, 500) + return + } + common.SuccessResp(c, res) +} diff --git a/server/handles/fsup.go b/server/handles/fsup.go new file mode 100644 index 0000000000000000000000000000000000000000..4b4e69c7282a27cbc3bef7e79783cb178708d39e --- /dev/null +++ b/server/handles/fsup.go @@ -0,0 +1,152 @@ +package handles + +import ( + "io" + "net/url" + stdpath "path" + "strconv" + "time" + + "github.com/alist-org/alist/v3/pkg/tache" + + "github.com/alist-org/alist/v3/internal/stream" + + "github.com/alist-org/alist/v3/internal/fs" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/server/common" + "github.com/gin-gonic/gin" +) + +func getLastModified(c *gin.Context) time.Time { + now := time.Now() + lastModifiedStr := c.GetHeader("Last-Modified") + lastModifiedMillisecond, err := strconv.ParseInt(lastModifiedStr, 10, 64) + if err != nil { + return now + } + lastModified := time.UnixMilli(lastModifiedMillisecond) + return lastModified +} + +func FsStream(c *gin.Context) { + path := c.GetHeader("File-Path") + path, err := url.PathUnescape(path) + if err != nil { + common.ErrorResp(c, err, 400) + return + } + asTask := c.GetHeader("As-Task") == "true" + user := c.MustGet("user").(*model.User) + path, err = user.JoinPath(path) + if err != nil { + common.ErrorResp(c, err, 403) + return + } + dir, name := stdpath.Split(path) + sizeStr := c.GetHeader("Content-Length") + size, err := strconv.ParseInt(sizeStr, 10, 64) + if err != nil { + common.ErrorResp(c, err, 400) + return + } + s := &stream.FileStream{ + Obj: &model.Object{ + Name: name, + Size: size, + Modified: getLastModified(c), + }, + Reader: c.Request.Body, + Mimetype: c.GetHeader("Content-Type"), + WebPutAsTask: asTask, + } + var t tache.TaskWithInfo + if asTask { + t, err = fs.PutAsTask(dir, s) + } else { + err = fs.PutDirectly(c, dir, s, true) + } + defer c.Request.Body.Close() + if err != nil { + common.ErrorResp(c, err, 500) + return + } + if t == nil { + common.SuccessResp(c) + return + } + common.SuccessResp(c, gin.H{ + "task": getTaskInfo(t), + }) +} + +func FsForm(c *gin.Context) { + path := c.GetHeader("File-Path") + path, err := url.PathUnescape(path) + if err != nil { + common.ErrorResp(c, err, 400) + return + } + asTask := c.GetHeader("As-Task") == "true" + user := c.MustGet("user").(*model.User) + path, err = user.JoinPath(path) + if err != nil { + common.ErrorResp(c, err, 403) + return + } + storage, err := fs.GetStorage(path, &fs.GetStoragesArgs{}) + if err != nil { + common.ErrorResp(c, err, 400) + return + } + if storage.Config().NoUpload { + common.ErrorStrResp(c, "Current storage doesn't support upload", 405) + return + } + file, err := c.FormFile("file") + if err != nil { + common.ErrorResp(c, err, 500) + return + } + f, err := file.Open() + if err != nil { + common.ErrorResp(c, err, 500) + return + } + defer f.Close() + dir, name := stdpath.Split(path) + s := stream.FileStream{ + Obj: &model.Object{ + Name: name, + Size: file.Size, + Modified: getLastModified(c), + }, + Reader: f, + Mimetype: file.Header.Get("Content-Type"), + WebPutAsTask: asTask, + } + var t tache.TaskWithInfo + if asTask { + s.Reader = struct { + io.Reader + }{f} + t, err = 
fs.PutAsTask(dir, &s)
+	} else {
+		// assign to the outer err: shadowing it with := would let a
+		// PutDirectly failure slip past the shared error check below
+		var ss *stream.SeekableStream
+		ss, err = stream.NewSeekableStream(s, nil)
+		if err != nil {
+			common.ErrorResp(c, err, 500)
+			return
+		}
+		err = fs.PutDirectly(c, dir, ss, true)
+	}
+	if err != nil {
+		common.ErrorResp(c, err, 500)
+		return
+	}
+	if t == nil {
+		common.SuccessResp(c)
+		return
+	}
+	common.SuccessResp(c, gin.H{
+		"task": getTaskInfo(t),
+	})
+}
diff --git a/server/handles/helper.go b/server/handles/helper.go new file mode 100644 index 0000000000000000000000000000000000000000..bd41c42c3bfb93aa07bacc2f24b1aa67740c4571 --- /dev/null +++ b/server/handles/helper.go @@ -0,0 +1,100 @@
+package handles
+
+import (
+	"fmt"
+	"net/url"
+	"strings"
+
+	"github.com/alist-org/alist/v3/internal/conf"
+	"github.com/alist-org/alist/v3/internal/setting"
+	"github.com/alist-org/alist/v3/pkg/utils"
+	"github.com/alist-org/alist/v3/server/common"
+	"github.com/gin-gonic/gin"
+)
+
+func Favicon(c *gin.Context) {
+	c.Redirect(302, setting.GetStr(conf.Favicon))
+}
+
+func Robots(c *gin.Context) {
+	c.String(200, setting.GetStr(conf.RobotsTxt))
+}
+
+func Plist(c *gin.Context) {
+	linkNameB64 := strings.TrimSuffix(c.Param("link_name"), ".plist")
+	linkName, err := utils.SafeAtob(linkNameB64)
+	if err != nil {
+		common.ErrorResp(c, err, 400)
+		return
+	}
+	linkNameSplit := strings.Split(linkName, "/")
+	if len(linkNameSplit) != 2 {
+		common.ErrorStrResp(c, "malformed link", 400)
+		return
+	}
+	linkEncode := linkNameSplit[0]
+	linkStr, err := url.PathUnescape(linkEncode)
+	if err != nil {
+		common.ErrorResp(c, err, 400)
+		return
+	}
+	link, err := url.Parse(linkStr)
+	if err != nil {
+		common.ErrorResp(c, err, 400)
+		return
+	}
+	fullName := c.Param("name")
+	Url := link.String()
+	Url = strings.ReplaceAll(Url, "<", "[")
+	Url = strings.ReplaceAll(Url, ">", "]")
+	nameEncode := linkNameSplit[1]
+	fullName, err = url.PathUnescape(nameEncode)
+	if err != nil {
+		common.ErrorResp(c, err, 400)
+		return
+	}
+	name := fullName
+	identifier := fmt.Sprintf("ci.nn.%s", url.PathEscape(fullName))
+	sep := "@"
+	if strings.Contains(fullName, sep) {
+		ss := strings.Split(fullName, sep)
+		name = strings.Join(ss[:len(ss)-1], sep)
+		identifier = ss[len(ss)-1]
+	}
+
+	name = strings.ReplaceAll(name, "<", "[")
+	name = strings.ReplaceAll(name, ">", "]")
+	plist := fmt.Sprintf(`<?xml version="1.0" encoding="UTF-8"?>
+<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
+<plist version="1.0">
+<dict>
+    <key>items</key>
+    <array>
+        <dict>
+            <key>assets</key>
+            <array>
+                <dict>
+                    <key>kind</key>
+                    <string>software-package</string>
+                    <key>url</key>
+                    <string><![CDATA[%s]]></string>
+                </dict>
+            </array>
+            <key>metadata</key>
+            <dict>
+                <key>bundle-identifier</key>
+                <string>%s</string>
+                <key>bundle-version</key>
+                <string>4.4</string>
+                <key>kind</key>
+                <string>software</string>
+                <key>title</key>
+                <string>%s</string>
+            </dict>
+        </dict>
+    </array>
+</dict>
+</plist>`, Url, identifier, name)
+	c.Header("Content-Type", "application/xml;charset=utf-8")
+	c.Status(200)
+	_, _ = c.Writer.WriteString(plist)
+}
diff --git a/server/handles/index.go b/server/handles/index.go new file mode 100644 index 0000000000000000000000000000000000000000..0fa1fa0e9bf08edeb289f4a9ff9b881b52a81b92 --- /dev/null +++ b/server/handles/index.go @@ -0,0 +1,105 @@
+package handles
+
+import (
+	"context"
+
+	"github.com/alist-org/alist/v3/internal/conf"
+	"github.com/alist-org/alist/v3/internal/model"
+	"github.com/alist-org/alist/v3/internal/search"
+	"github.com/alist-org/alist/v3/internal/setting"
+	"github.com/alist-org/alist/v3/server/common"
+	"github.com/gin-gonic/gin"
+	log "github.com/sirupsen/logrus"
+)
+
+type UpdateIndexReq struct {
+	Paths    []string `json:"paths"`
+	MaxDepth int      `json:"max_depth"`
+	//IgnorePaths []string `json:"ignore_paths"`
+}
+
+func BuildIndex(c *gin.Context) {
+	if search.Running.Load() {
+		common.ErrorStrResp(c, "index is running", 400)
+		return
+	}
+	go func() {
+		ctx := context.Background()
+		err := search.Clear(ctx)
+		if err
!= nil { + log.Errorf("clear index error: %+v", err) + return + } + err = search.BuildIndex(context.Background(), []string{"/"}, + conf.SlicesMap[conf.IgnorePaths], setting.GetInt(conf.MaxIndexDepth, 20), true) + if err != nil { + log.Errorf("build index error: %+v", err) + } + }() + common.SuccessResp(c) +} + +func UpdateIndex(c *gin.Context) { + var req UpdateIndexReq + if err := c.ShouldBind(&req); err != nil { + common.ErrorResp(c, err, 400) + return + } + if search.Running.Load() { + common.ErrorStrResp(c, "index is running", 400) + return + } + if !search.Config(c).AutoUpdate { + common.ErrorStrResp(c, "update is not supported for current index", 400) + return + } + go func() { + ctx := context.Background() + for _, path := range req.Paths { + err := search.Del(ctx, path) + if err != nil { + log.Errorf("delete index on %s error: %+v", path, err) + return + } + } + err := search.BuildIndex(context.Background(), req.Paths, + conf.SlicesMap[conf.IgnorePaths], req.MaxDepth, false) + if err != nil { + log.Errorf("update index error: %+v", err) + } + }() + common.SuccessResp(c) +} + +func StopIndex(c *gin.Context) { + if !search.Running.Load() { + common.ErrorStrResp(c, "index is not running", 400) + return + } + search.Quit <- struct{}{} + common.SuccessResp(c) +} + +func ClearIndex(c *gin.Context) { + if search.Running.Load() { + common.ErrorStrResp(c, "index is running", 400) + return + } + search.Clear(c) + search.WriteProgress(&model.IndexProgress{ + ObjCount: 0, + IsDone: true, + LastDoneTime: nil, + Error: "", + }) + common.SuccessResp(c) +} + +func GetProgress(c *gin.Context) { + progress, err := search.Progress() + if err != nil { + common.ErrorResp(c, err, 500) + return + } + common.SuccessResp(c, progress) +} diff --git a/server/handles/ldap_login.go b/server/handles/ldap_login.go new file mode 100644 index 0000000000000000000000000000000000000000..cf3148291b10c2b5955332bb0da90ac86d06338f --- /dev/null +++ b/server/handles/ldap_login.go @@ -0,0 +1,157 @@ +package handles + +import ( + "crypto/tls" + "errors" + "fmt" + "strings" + + "github.com/alist-org/alist/v3/internal/conf" + "github.com/alist-org/alist/v3/internal/db" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/internal/op" + "github.com/alist-org/alist/v3/internal/setting" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/alist-org/alist/v3/pkg/utils/random" + "github.com/alist-org/alist/v3/server/common" + "github.com/gin-gonic/gin" + "gopkg.in/ldap.v3" +) + +func LoginLdap(c *gin.Context) { + var req LoginReq + if err := c.ShouldBind(&req); err != nil { + common.ErrorResp(c, err, 400) + return + } + loginLdap(c, &req) +} + +func loginLdap(c *gin.Context, req *LoginReq) { + enabled := setting.GetBool(conf.LdapLoginEnabled) + if !enabled { + common.ErrorStrResp(c, "ldap is not enabled", 403) + return + } + + // check count of login + ip := c.ClientIP() + count, ok := loginCache.Get(ip) + if ok && count >= defaultTimes { + common.ErrorStrResp(c, "Too many unsuccessful sign-in attempts have been made using an incorrect username or password, Try again later.", 429) + loginCache.Expire(ip, defaultDuration) + return + } + + // Auth start + ldapServer := setting.GetStr(conf.LdapServer) + ldapManagerDN := setting.GetStr(conf.LdapManagerDN) + ldapManagerPassword := setting.GetStr(conf.LdapManagerPassword) + ldapUserSearchBase := setting.GetStr(conf.LdapUserSearchBase) + ldapUserSearchFilter := setting.GetStr(conf.LdapUserSearchFilter) // (uid=%s) + + // Connect to LdapServer + l, err := 
dial(ldapServer)
+	if err != nil {
+		utils.Log.Errorf("failed to connect to LDAP: %v", err)
+		common.ErrorResp(c, err, 500)
+		return
+	}
+
+	// First bind with a read-only user
+	if ldapManagerDN != "" && ldapManagerPassword != "" {
+		err = l.Bind(ldapManagerDN, ldapManagerPassword)
+		if err != nil {
+			utils.Log.Errorf("Failed to bind to LDAP: %v", err)
+			common.ErrorResp(c, err, 500)
+			return
+		}
+	}
+
+	// Search for the given username
+	searchRequest := ldap.NewSearchRequest(
+		ldapUserSearchBase,
+		ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 0, 0, false,
+		fmt.Sprintf(ldapUserSearchFilter, req.Username),
+		[]string{"dn"},
+		nil,
+	)
+	sr, err := l.Search(searchRequest)
+	if err != nil {
+		utils.Log.Errorf("LDAP search failed: %v", err)
+		common.ErrorResp(c, err, 500)
+		return
+	}
+	if len(sr.Entries) != 1 {
+		// err is nil at this point, so report the condition itself
+		utils.Log.Errorf("User does not exist or too many entries returned")
+		common.ErrorStrResp(c, "user does not exist or too many entries returned", 500)
+		return
+	}
+	userDN := sr.Entries[0].DN
+
+	// Bind as the user to verify their password
+	err = l.Bind(userDN, req.Password)
+	if err != nil {
+		utils.Log.Errorf("Failed to authenticate: %v", err)
+		common.ErrorResp(c, err, 400)
+		loginCache.Set(ip, count+1)
+		return
+	}
+	utils.Log.Infof("Auth successful username:%s", req.Username)
+	// Auth finished
+
+	user, err := op.GetUserByName(req.Username)
+	if err != nil {
+		user, err = ldapRegister(req.Username)
+		if err != nil {
+			common.ErrorResp(c, err, 400)
+			loginCache.Set(ip, count+1)
+			return
+		}
+	}
+
+	// generate token
+	token, err := common.GenerateToken(user)
+	if err != nil {
+		common.ErrorResp(c, err, 400, true)
+		return
+	}
+	common.SuccessResp(c, gin.H{"token": token})
+	loginCache.Del(ip)
+}
+
+func ldapRegister(username string) (*model.User, error) {
+	if username == "" {
+		return nil, errors.New("cannot get username from ldap provider")
+	}
+	user := &model.User{
+		ID:         0,
+		Username:   username,
+		Password:   random.String(16),
+		Permission: int32(setting.GetInt(conf.LdapDefaultPermission, 0)),
+		BasePath:   setting.GetStr(conf.LdapDefaultDir),
+		Role:       0,
+		Disabled:   false,
+	}
+	if err := db.CreateUser(user); err != nil {
+		return nil, err
+	}
+	return user, nil
+}
+
+func dial(ldapServer string) (*ldap.Conn, error) {
+	tlsEnabled := false
+	if strings.HasPrefix(ldapServer, "ldaps://") {
+		tlsEnabled = true
+		ldapServer = strings.TrimPrefix(ldapServer, "ldaps://")
+	} else if strings.HasPrefix(ldapServer, "ldap://") {
+		ldapServer = strings.TrimPrefix(ldapServer, "ldap://")
+	}
+
+	if tlsEnabled {
+		// note: certificate verification is skipped for ldaps connections
+		return ldap.DialTLS("tcp", ldapServer, &tls.Config{InsecureSkipVerify: true})
+	}
+	return ldap.Dial("tcp", ldapServer)
+}
diff --git a/server/handles/meta.go b/server/handles/meta.go new file mode 100644 index 0000000000000000000000000000000000000000..00aa31371922f8eb0f42ebecaba59d9cb6bfb017 --- /dev/null +++ b/server/handles/meta.go @@ -0,0 +1,109 @@
+package handles
+
+import (
+	"fmt"
+	"strconv"
+	"strings"
+
+	"github.com/alist-org/alist/v3/internal/model"
+	"github.com/alist-org/alist/v3/internal/op"
+	"github.com/alist-org/alist/v3/server/common"
+	"github.com/dlclark/regexp2"
+	"github.com/gin-gonic/gin"
+	log "github.com/sirupsen/logrus"
+)
+
+func ListMetas(c *gin.Context) {
+	var req model.PageReq
+	if err := c.ShouldBind(&req); err != nil {
+		common.ErrorResp(c, err, 400)
+		return
+	}
+	req.Validate()
+	log.Debugf("%+v", req)
+	metas, total, err := op.GetMetas(req.Page, req.PerPage)
+	if err != nil {
+		common.ErrorResp(c, err, 500, true)
+		return
+	}
+	common.SuccessResp(c, common.PageResp{
+		Content: metas,
+		Total:   total,
+	})
+}
+
+func CreateMeta(c *gin.Context) {
+	var req model.Meta
+	if err := c.ShouldBind(&req); err != nil {
+		common.ErrorResp(c, err, 400)
+		return
+	}
+	r, err := validHide(req.Hide)
+	if err != nil {
+		common.ErrorStrResp(c, fmt.Sprintf("%s is invalid: %s", r, err.Error()), 400)
+		return
+	}
+	if err := op.CreateMeta(&req); err != nil {
+		common.ErrorResp(c, err, 500, true)
+	} else {
+		common.SuccessResp(c)
+	}
+}
+
+func UpdateMeta(c *gin.Context) {
+	var req model.Meta
+	if err := c.ShouldBind(&req); err != nil {
+		common.ErrorResp(c, err, 400)
+		return
+	}
+	r, err := validHide(req.Hide)
+	if err != nil {
+		common.ErrorStrResp(c, fmt.Sprintf("%s is invalid: %s", r, err.Error()), 400)
+		return
+	}
+	if err := op.UpdateMeta(&req); err != nil {
+		common.ErrorResp(c, err, 500, true)
+	} else {
+		common.SuccessResp(c)
+	}
+}
+
+// validHide checks that every line of the hide rule compiles as a regular
+// expression, returning the first offending line on failure.
+func validHide(hide string) (string, error) {
+	rs := strings.Split(hide, "\n")
+	for _, r := range rs {
+		_, err := regexp2.Compile(r, regexp2.None)
+		if err != nil {
+			return r, err
+		}
+	}
+	return "", nil
+}
+
+func DeleteMeta(c *gin.Context) {
+	idStr := c.Query("id")
+	id, err := strconv.Atoi(idStr)
+	if err != nil {
+		common.ErrorResp(c, err, 400)
+		return
+	}
+	if err := op.DeleteMetaById(uint(id)); err != nil {
+		common.ErrorResp(c, err, 500, true)
+		return
+	}
+	common.SuccessResp(c)
+}
+
+func GetMeta(c *gin.Context) {
+	idStr := c.Query("id")
+	id, err := strconv.Atoi(idStr)
+	if err != nil {
+		common.ErrorResp(c, err, 400)
+		return
+	}
+	meta, err := op.GetMetaById(uint(id))
+	if err != nil {
+		common.ErrorResp(c, err, 500, true)
+		return
+	}
+	common.SuccessResp(c, meta)
+}
diff --git a/server/handles/offline_download.go b/server/handles/offline_download.go new file mode 100644 index 0000000000000000000000000000000000000000..8f224ea2a607f4fd3361fc9a8c60d6408976757b --- /dev/null +++ b/server/handles/offline_download.go @@ -0,0 +1,118 @@
+package handles
+
+import (
+	"github.com/alist-org/alist/v3/internal/conf"
+	"github.com/alist-org/alist/v3/internal/model"
+	"github.com/alist-org/alist/v3/internal/offline_download/tool"
+	"github.com/alist-org/alist/v3/internal/op"
+	"github.com/alist-org/alist/v3/pkg/tache"
+	"github.com/alist-org/alist/v3/server/common"
+	"github.com/gin-gonic/gin"
+)
+
+type SetAria2Req struct {
+	Uri    string `json:"uri" form:"uri"`
+	Secret string `json:"secret" form:"secret"`
+}
+
+func SetAria2(c *gin.Context) {
+	var req SetAria2Req
+	if err := c.ShouldBind(&req); err != nil {
+		common.ErrorResp(c, err, 400)
+		return
+	}
+	items := []model.SettingItem{
+		{Key: conf.Aria2Uri, Value: req.Uri, Type: conf.TypeString, Group: model.OFFLINE_DOWNLOAD, Flag: model.PRIVATE},
+		{Key: conf.Aria2Secret, Value: req.Secret, Type: conf.TypeString, Group: model.OFFLINE_DOWNLOAD, Flag: model.PRIVATE},
+	}
+	if err := op.SaveSettingItems(items); err != nil {
+		common.ErrorResp(c, err, 500)
+		return
+	}
+	// the aria2 tool must be registered before it can be initialized
+	_tool, err := tool.Tools.Get("aria2")
+	if err != nil {
+		common.ErrorResp(c, err, 500)
+		return
+	}
+	version, err := _tool.Init()
+	if err != nil {
+		common.ErrorResp(c, err, 500)
+		return
+	}
+	common.SuccessResp(c, version)
+}
+
+type SetQbittorrentReq struct {
+	Url      string `json:"url" form:"url"`
+	Seedtime string `json:"seedtime" form:"seedtime"`
+}
+
+func SetQbittorrent(c *gin.Context) {
+	var req SetQbittorrentReq
+	if err := c.ShouldBind(&req); err != nil {
+		common.ErrorResp(c, err, 400)
+		return
+	}
+	items := []model.SettingItem{
+		{Key: conf.QbittorrentUrl, Value: req.Url, Type: conf.TypeString, Group: model.OFFLINE_DOWNLOAD, Flag:
model.PRIVATE}, + {Key: conf.QbittorrentSeedtime, Value: req.Seedtime, Type: conf.TypeNumber, Group: model.OFFLINE_DOWNLOAD, Flag: model.PRIVATE}, + } + if err := op.SaveSettingItems(items); err != nil { + common.ErrorResp(c, err, 500) + return + } + _tool, err := tool.Tools.Get("qBittorrent") + if err != nil { + common.ErrorResp(c, err, 500) + return + } + if _, err := _tool.Init(); err != nil { + common.ErrorResp(c, err, 500) + return + } + common.SuccessResp(c, "ok") +} + +func OfflineDownloadTools(c *gin.Context) { + tools := tool.Tools.Names() + common.SuccessResp(c, tools) +} + +type AddOfflineDownloadReq struct { + Urls []string `json:"urls"` + Path string `json:"path"` + Tool string `json:"tool"` + DeletePolicy string `json:"delete_policy"` +} + +func AddOfflineDownload(c *gin.Context) { + user := c.MustGet("user").(*model.User) + if !user.CanAddOfflineDownloadTasks() { + common.ErrorStrResp(c, "permission denied", 403) + return + } + + var req AddOfflineDownloadReq + if err := c.ShouldBind(&req); err != nil { + common.ErrorResp(c, err, 400) + return + } + reqPath, err := user.JoinPath(req.Path) + if err != nil { + common.ErrorResp(c, err, 403) + return + } + var tasks []tache.TaskWithInfo + for _, url := range req.Urls { + t, err := tool.AddURL(c, &tool.AddURLArgs{ + URL: url, + DstDirPath: reqPath, + Tool: req.Tool, + DeletePolicy: tool.DeletePolicy(req.DeletePolicy), + }) + if err != nil { + common.ErrorResp(c, err, 500) + return + } + tasks = append(tasks, t) + } + common.SuccessResp(c, gin.H{ + "tasks": getTaskInfos(tasks), + }) +} diff --git a/server/handles/search.go b/server/handles/search.go new file mode 100644 index 0000000000000000000000000000000000000000..8881731bd606310c220a077bcf7dbea928162c62 --- /dev/null +++ b/server/handles/search.go @@ -0,0 +1,76 @@ +package handles + +import ( + "path" + "strings" + + "github.com/alist-org/alist/v3/internal/errs" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/internal/op" + "github.com/alist-org/alist/v3/internal/search" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/alist-org/alist/v3/server/common" + "github.com/gin-gonic/gin" + "github.com/pkg/errors" +) + +type SearchReq struct { + model.SearchReq + Password string `json:"password"` +} + +type SearchResp struct { + model.SearchNode + Type int `json:"type"` +} + +func Search(c *gin.Context) { + var ( + req SearchReq + err error + ) + if err = c.ShouldBind(&req); err != nil { + common.ErrorResp(c, err, 400) + return + } + user := c.MustGet("user").(*model.User) + req.Parent, err = user.JoinPath(req.Parent) + if err != nil { + common.ErrorResp(c, err, 400) + return + } + if err := req.Validate(); err != nil { + common.ErrorResp(c, err, 400) + return + } + nodes, total, err := search.Search(c, req.SearchReq) + if err != nil { + common.ErrorResp(c, err, 500) + return + } + var filteredNodes []model.SearchNode + for _, node := range nodes { + if !strings.HasPrefix(node.Parent, user.BasePath) { + continue + } + meta, err := op.GetNearestMeta(node.Parent) + if err != nil && !errors.Is(errors.Cause(err), errs.MetaNotFound) { + continue + } + if !common.CanAccess(user, meta, path.Join(node.Parent, node.Name), req.Password) { + continue + } + filteredNodes = append(filteredNodes, node) + } + common.SuccessResp(c, common.PageResp{ + Content: utils.MustSliceConvert(filteredNodes, nodeToSearchResp), + Total: total, + }) +} + +func nodeToSearchResp(node model.SearchNode) SearchResp { + return SearchResp{ + SearchNode: node, + Type: 
utils.GetObjType(node.Name, node.IsDir),
+	}
+}
diff --git a/server/handles/setting.go b/server/handles/setting.go new file mode 100644 index 0000000000000000000000000000000000000000..0454c7aa6e0a56f3ab20ccd6d78cb364a16b03e0 --- /dev/null +++ b/server/handles/setting.go @@ -0,0 +1,112 @@
+package handles
+
+import (
+	"strconv"
+	"strings"
+
+	"github.com/alist-org/alist/v3/internal/conf"
+	"github.com/alist-org/alist/v3/internal/model"
+	"github.com/alist-org/alist/v3/internal/op"
+	"github.com/alist-org/alist/v3/internal/setting"
+	"github.com/alist-org/alist/v3/internal/sign"
+	"github.com/alist-org/alist/v3/pkg/utils/random"
+	"github.com/alist-org/alist/v3/server/common"
+	"github.com/alist-org/alist/v3/server/static"
+	"github.com/gin-gonic/gin"
+)
+
+func ResetToken(c *gin.Context) {
+	token := random.Token()
+	item := model.SettingItem{Key: "token", Value: token, Type: conf.TypeString, Group: model.SINGLE, Flag: model.PRIVATE}
+	if err := op.SaveSettingItem(&item); err != nil {
+		common.ErrorResp(c, err, 500)
+		return
+	}
+	sign.Instance()
+	common.SuccessResp(c, token)
+}
+
+func GetSetting(c *gin.Context) {
+	key := c.Query("key")
+	keys := c.Query("keys")
+	if key != "" {
+		item, err := op.GetSettingItemByKey(key)
+		if err != nil {
+			common.ErrorResp(c, err, 400)
+			return
+		}
+		common.SuccessResp(c, item)
+	} else {
+		items, err := op.GetSettingItemInKeys(strings.Split(keys, ","))
+		if err != nil {
+			common.ErrorResp(c, err, 400)
+			return
+		}
+		common.SuccessResp(c, items)
+	}
+}
+
+func SaveSettings(c *gin.Context) {
+	var req []model.SettingItem
+	if err := c.ShouldBind(&req); err != nil {
+		common.ErrorResp(c, err, 400)
+		return
+	}
+	if err := op.SaveSettingItems(req); err != nil {
+		common.ErrorResp(c, err, 500)
+	} else {
+		// a save in settings group 10 fires a test notification when
+		// notifications are enabled
+		if len(req) > 0 && req[0].Group == 10 {
+			title := setting.GetStr(conf.SiteTitle)
+			if setting.GetBool(conf.NotifyEnabled) {
+				go op.Notify(title+" test notification", "Welcome!")
+			}
+		}
+		common.SuccessResp(c)
+		static.UpdateIndex()
+	}
+}
+
+func ListSettings(c *gin.Context) {
+	groupStr := c.Query("group")
+	groupsStr := c.Query("groups")
+	var settings []model.SettingItem
+	var err error
+	if groupsStr == "" && groupStr == "" {
+		settings, err = op.GetSettingItems()
+	} else {
+		var groupStrings []string
+		if groupsStr != "" {
+			groupStrings = strings.Split(groupsStr, ",")
+		} else {
+			groupStrings = append(groupStrings, groupStr)
+		}
+		var groups []int
+		for _, str := range groupStrings {
+			group, err := strconv.Atoi(str)
+			if err != nil {
+				common.ErrorResp(c, err, 400)
+				return
+			}
+			groups = append(groups, group)
+		}
+		settings, err = op.GetSettingItemsInGroups(groups)
+	}
+	if err != nil {
+		common.ErrorResp(c, err, 400)
+		return
+	}
+	common.SuccessResp(c, settings)
+}
+
+func DeleteSetting(c *gin.Context) {
+	key := c.Query("key")
+	if err := op.DeleteSettingItemByKey(key); err != nil {
+		common.ErrorResp(c, err, 500)
+		return
+	}
+	common.SuccessResp(c)
+}
+
+func PublicSettings(c *gin.Context) {
+	common.SuccessResp(c, op.GetPublicSettingsMap())
+}
diff --git a/server/handles/ssologin.go b/server/handles/ssologin.go new file mode 100644 index 0000000000000000000000000000000000000000..70298a9c3f0a64fca713ca71f25f7d82e0304e48 --- /dev/null +++ b/server/handles/ssologin.go @@ -0,0 +1,446 @@
+package handles
+
+import (
+	"encoding/base32"
+	"encoding/base64"
+	"errors"
+	"fmt"
+	"net/http"
+	"net/url"
+	"path"
+	"strings"
+	"time"
+
+	"github.com/alist-org/alist/v3/internal/conf"
+	"github.com/alist-org/alist/v3/internal/db"
"github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/internal/setting" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/alist-org/alist/v3/pkg/utils/random" + "github.com/alist-org/alist/v3/server/common" + "github.com/coreos/go-oidc" + "github.com/gin-gonic/gin" + "github.com/go-resty/resty/v2" + "github.com/pquerna/otp" + "github.com/pquerna/otp/totp" + "golang.org/x/oauth2" + "gorm.io/gorm" +) + +var opts = totp.ValidateOpts{ + // state verify won't expire in 30 secs, which is quite enough for the callback + Period: 30, + Skew: 1, + // in some OIDC providers(such as Authelia), state parameter must be at least 8 characters + Digits: otp.DigitsEight, + Algorithm: otp.AlgorithmSHA1, +} + +func SSOLoginRedirect(c *gin.Context) { + method := c.Query("method") + usecompatibility := setting.GetBool(conf.SSOCompatibilityMode) + enabled := setting.GetBool(conf.SSOLoginEnabled) + clientId := setting.GetStr(conf.SSOClientId) + platform := setting.GetStr(conf.SSOLoginPlatform) + var r_url string + var redirect_uri string + if !enabled { + common.ErrorStrResp(c, "Single sign-on is not enabled", 403) + return + } + urlValues := url.Values{} + if method == "" { + common.ErrorStrResp(c, "no method provided", 400) + return + } + if usecompatibility { + redirect_uri = common.GetApiUrl(c.Request) + "/api/auth/" + method + } else { + redirect_uri = common.GetApiUrl(c.Request) + "/api/auth/sso_callback" + "?method=" + method + } + urlValues.Add("response_type", "code") + urlValues.Add("redirect_uri", redirect_uri) + urlValues.Add("client_id", clientId) + switch platform { + case "Github": + r_url = "https://github.com/login/oauth/authorize?" + urlValues.Add("scope", "read:user") + case "Microsoft": + r_url = "https://login.microsoftonline.com/common/oauth2/v2.0/authorize?" + urlValues.Add("scope", "user.read") + urlValues.Add("response_mode", "query") + case "Google": + r_url = "https://accounts.google.com/o/oauth2/v2/auth?" + urlValues.Add("scope", "https://www.googleapis.com/auth/userinfo.profile") + case "Dingtalk": + r_url = "https://login.dingtalk.com/oauth2/auth?" + urlValues.Add("scope", "openid") + urlValues.Add("prompt", "consent") + urlValues.Add("response_type", "code") + case "Casdoor": + endpoint := strings.TrimSuffix(setting.GetStr(conf.SSOEndpointName), "/") + r_url = endpoint + "/login/oauth/authorize?" 
+ urlValues.Add("scope", "profile") + urlValues.Add("state", endpoint) + case "OIDC": + oauth2Config, err := GetOIDCClient(c) + if err != nil { + common.ErrorStrResp(c, err.Error(), 400) + return + } + // generate state parameter + state, err := totp.GenerateCodeCustom(base32.StdEncoding.EncodeToString([]byte(oauth2Config.ClientSecret)), time.Now(), opts) + if err != nil { + common.ErrorStrResp(c, err.Error(), 400) + return + } + c.Redirect(http.StatusFound, oauth2Config.AuthCodeURL(state)) + return + default: + common.ErrorStrResp(c, "invalid platform", 400) + return + } + c.Redirect(302, r_url+urlValues.Encode()) +} + +var ssoClient = resty.New().SetRetryCount(3) + +func GetOIDCClient(c *gin.Context) (*oauth2.Config, error) { + var redirect_uri string + usecompatibility := setting.GetBool(conf.SSOCompatibilityMode) + argument := c.Query("method") + if usecompatibility { + argument = path.Base(c.Request.URL.Path) + } + if usecompatibility { + redirect_uri = common.GetApiUrl(c.Request) + "/api/auth/" + argument + } else { + redirect_uri = common.GetApiUrl(c.Request) + "/api/auth/sso_callback" + "?method=" + argument + } + endpoint := setting.GetStr(conf.SSOEndpointName) + provider, err := oidc.NewProvider(c, endpoint) + if err != nil { + return nil, err + } + clientId := setting.GetStr(conf.SSOClientId) + clientSecret := setting.GetStr(conf.SSOClientSecret) + return &oauth2.Config{ + ClientID: clientId, + ClientSecret: clientSecret, + RedirectURL: redirect_uri, + + // Discovery returns the OAuth2 endpoints. + Endpoint: provider.Endpoint(), + + // "openid" is a required scope for OpenID Connect flows. + Scopes: []string{oidc.ScopeOpenID, "profile"}, + }, nil +} + +func autoRegister(username, userID string, err error) (*model.User, error) { + if !errors.Is(err, gorm.ErrRecordNotFound) || !setting.GetBool(conf.SSOAutoRegister) { + return nil, err + } + if username == "" { + return nil, errors.New("cannot get username from SSO provider") + } + user := &model.User{ + ID: 0, + Username: username, + Password: random.String(16), + Permission: int32(setting.GetInt(conf.SSODefaultPermission, 0)), + BasePath: setting.GetStr(conf.SSODefaultDir), + Role: 0, + Disabled: false, + SsoID: userID, + } + if err = db.CreateUser(user); err != nil { + if strings.HasPrefix(err.Error(), "UNIQUE constraint failed") && strings.HasSuffix(err.Error(), "username") { + user.Username = user.Username + "_" + userID + if err = db.CreateUser(user); err != nil { + return nil, err + } + } else { + return nil, err + } + } + return user, nil +} + +func parseJWT(p string) ([]byte, error) { + parts := strings.Split(p, ".") + if len(parts) < 2 { + return nil, fmt.Errorf("oidc: malformed jwt, expected 3 parts got %d", len(parts)) + } + payload, err := base64.RawURLEncoding.DecodeString(parts[1]) + if err != nil { + return nil, fmt.Errorf("oidc: malformed jwt payload: %v", err) + } + return payload, nil +} + +func OIDCLoginCallback(c *gin.Context) { + useCompatibility := setting.GetBool(conf.SSOCompatibilityMode) + argument := c.Query("method") + if useCompatibility { + argument = path.Base(c.Request.URL.Path) + } + clientId := setting.GetStr(conf.SSOClientId) + endpoint := setting.GetStr(conf.SSOEndpointName) + provider, err := oidc.NewProvider(c, endpoint) + if err != nil { + common.ErrorResp(c, err, 400) + return + } + oauth2Config, err := GetOIDCClient(c) + if err != nil { + common.ErrorResp(c, err, 400) + return + } + // add state verify process + stateVerification, err := totp.ValidateCustom(c.Query("state"), 
base32.StdEncoding.EncodeToString([]byte(oauth2Config.ClientSecret)), time.Now(), opts)
+	if err != nil {
+		common.ErrorResp(c, err, 400)
+		return
+	}
+	if !stateVerification {
+		common.ErrorStrResp(c, "incorrect or expired state parameter", 400)
+		return
+	}
+
+	oauth2Token, err := oauth2Config.Exchange(c, c.Query("code"))
+	if err != nil {
+		common.ErrorResp(c, err, 400)
+		return
+	}
+	rawIDToken, ok := oauth2Token.Extra("id_token").(string)
+	if !ok {
+		common.ErrorStrResp(c, "no id_token found in oauth2 token", 400)
+		return
+	}
+	verifier := provider.Verifier(&oidc.Config{
+		ClientID: clientId,
+	})
+	_, err = verifier.Verify(c, rawIDToken)
+	if err != nil {
+		common.ErrorResp(c, err, 400)
+		return
+	}
+	payload, err := parseJWT(rawIDToken)
+	if err != nil {
+		common.ErrorResp(c, err, 400)
+		return
+	}
+	userID := utils.Json.Get(payload, setting.GetStr(conf.SSOOIDCUsernameKey, "name")).ToString()
+	if userID == "" {
+		common.ErrorStrResp(c, "cannot get username from OIDC provider", 400)
+		return
+	}
+	if argument == "get_sso_id" {
+		if useCompatibility {
+			c.Redirect(302, common.GetApiUrl(c.Request)+"/@manage?sso_id="+userID)
+			return
+		}
+		html := fmt.Sprintf(`<!DOCTYPE html>
+<head></head>
+<body>
+<script>
+window.opener.postMessage({"sso_id": "%s"}, "*")
+window.close()
+</script>
+</body>`, userID)
+		c.Data(200, "text/html; charset=utf-8", []byte(html))
+		return
+	}
+	if argument == "sso_get_token" {
+		user, err := db.GetUserBySSOID(userID)
+		if err != nil {
+			user, err = autoRegister(userID, userID, err)
+			if err != nil {
+				common.ErrorResp(c, err, 400)
+				return
+			}
+		}
+		token, err := common.GenerateToken(user)
+		if err != nil {
+			common.ErrorResp(c, err, 400)
+			return
+		}
+		if useCompatibility {
+			c.Redirect(302, common.GetApiUrl(c.Request)+"/@login?token="+token)
+			return
+		}
+		html := fmt.Sprintf(`<!DOCTYPE html>
+<head></head>
+<body>
+<script>
+window.opener.postMessage({"token": "%s"}, "*")
+window.close()
+</script>
+</body>`, token)
+		c.Data(200, "text/html; charset=utf-8", []byte(html))
+		return
+	}
+}
+
+func SSOLoginCallback(c *gin.Context) {
+	enabled := setting.GetBool(conf.SSOLoginEnabled)
+	usecompatibility := setting.GetBool(conf.SSOCompatibilityMode)
+	if !enabled {
+		common.ErrorResp(c, errors.New("sso login is disabled"), 500)
+		return
+	}
+	argument := c.Query("method")
+	if usecompatibility {
+		argument = path.Base(c.Request.URL.Path)
+	}
+	if !utils.SliceContains([]string{"get_sso_id", "sso_get_token"}, argument) {
+		common.ErrorResp(c, errors.New("invalid request"), 500)
+		return
+	}
+	clientId := setting.GetStr(conf.SSOClientId)
+	platform := setting.GetStr(conf.SSOLoginPlatform)
+	clientSecret := setting.GetStr(conf.SSOClientSecret)
+	var tokenUrl, userUrl, scope, authField, idField, usernameField string
+	additionalForm := make(map[string]string)
+	switch platform {
+	case "Github":
+		tokenUrl = "https://github.com/login/oauth/access_token"
+		userUrl = "https://api.github.com/user"
+		authField = "code"
+		scope = "read:user"
+		idField = "id"
+		usernameField = "login"
+	case "Microsoft":
+		tokenUrl = "https://login.microsoftonline.com/common/oauth2/v2.0/token"
+		userUrl = "https://graph.microsoft.com/v1.0/me"
+		additionalForm["grant_type"] = "authorization_code"
+		scope = "user.read"
+		authField = "code"
+		idField = "id"
+		usernameField = "displayName"
+	case "Google":
+		tokenUrl = "https://oauth2.googleapis.com/token"
+		userUrl = "https://www.googleapis.com/oauth2/v1/userinfo"
+		additionalForm["grant_type"] = "authorization_code"
+		scope = "https://www.googleapis.com/auth/userinfo.profile"
+		authField = "code"
+		idField = "id"
+		usernameField = "name"
+	case "Dingtalk":
+		tokenUrl = "https://api.dingtalk.com/v1.0/oauth2/userAccessToken"
+		userUrl = "https://api.dingtalk.com/v1.0/contact/users/me"
+		authField =
"authCode" + idField = "unionId" + usernameField = "nick" + case "Casdoor": + endpoint := strings.TrimSuffix(setting.GetStr(conf.SSOEndpointName), "/") + tokenUrl = endpoint + "/api/login/oauth/access_token" + userUrl = endpoint + "/api/userinfo" + additionalForm["grant_type"] = "authorization_code" + scope = "profile" + authField = "code" + idField = "sub" + usernameField = "preferred_username" + case "OIDC": + OIDCLoginCallback(c) + return + default: + common.ErrorStrResp(c, "invalid platform", 400) + return + } + callbackCode := c.Query(authField) + if callbackCode == "" { + common.ErrorStrResp(c, "No code provided", 400) + return + } + var resp *resty.Response + var err error + if platform == "Dingtalk" { + resp, err = ssoClient.R().SetHeader("content-type", "application/json").SetHeader("Accept", "application/json"). + SetBody(map[string]string{ + "clientId": clientId, + "clientSecret": clientSecret, + "code": callbackCode, + "grantType": "authorization_code", + }). + Post(tokenUrl) + } else { + var redirect_uri string + if usecompatibility { + redirect_uri = common.GetApiUrl(c.Request) + "/api/auth/" + argument + } else { + redirect_uri = common.GetApiUrl(c.Request) + "/api/auth/sso_callback" + "?method=" + argument + } + resp, err = ssoClient.R().SetHeader("Accept", "application/json"). + SetFormData(map[string]string{ + "client_id": clientId, + "client_secret": clientSecret, + "code": callbackCode, + "redirect_uri": redirect_uri, + "scope": scope, + }).SetFormData(additionalForm).Post(tokenUrl) + } + if err != nil { + common.ErrorResp(c, err, 400) + return + } + if platform == "Dingtalk" { + accessToken := utils.Json.Get(resp.Body(), "accessToken").ToString() + resp, err = ssoClient.R().SetHeader("x-acs-dingtalk-access-token", accessToken). + Get(userUrl) + } else { + accessToken := utils.Json.Get(resp.Body(), "access_token").ToString() + resp, err = ssoClient.R().SetHeader("Authorization", "Bearer "+accessToken). 
+			Get(userUrl)
+	}
+	if err != nil {
+		common.ErrorResp(c, err, 400)
+		return
+	}
+	userID := utils.Json.Get(resp.Body(), idField).ToString()
+	if utils.SliceContains([]string{"", "0"}, userID) {
+		common.ErrorResp(c, errors.New("failed to get user id from SSO provider"), 400)
+		return
+	}
+	if argument == "get_sso_id" {
+		if usecompatibility {
+			c.Redirect(302, common.GetApiUrl(c.Request)+"/@manage?sso_id="+userID)
+			return
+		}
+		html := fmt.Sprintf(`<!DOCTYPE html>
+<head></head>
+<body>
+<script>
+window.opener.postMessage({"sso_id": "%s"}, "*")
+window.close()
+</script>
+</body>`, userID)
+		c.Data(200, "text/html; charset=utf-8", []byte(html))
+		return
+	}
+	username := utils.Json.Get(resp.Body(), usernameField).ToString()
+	user, err := db.GetUserBySSOID(userID)
+	if err != nil {
+		user, err = autoRegister(username, userID, err)
+		if err != nil {
+			common.ErrorResp(c, err, 400)
+			return
+		}
+	}
+	token, err := common.GenerateToken(user)
+	if err != nil {
+		common.ErrorResp(c, err, 400)
+		return
+	}
+	if usecompatibility {
+		c.Redirect(302, common.GetApiUrl(c.Request)+"/@login?token="+token)
+		return
+	}
+	html := fmt.Sprintf(`<!DOCTYPE html>
+<head></head>
+<body>
+<script>
+window.opener.postMessage({"token": "%s"}, "*")
+window.close()
+</script>
+</body>`, token)
+	c.Data(200, "text/html; charset=utf-8", []byte(html))
+}
diff --git a/server/handles/storage.go b/server/handles/storage.go new file mode 100644 index 0000000000000000000000000000000000000000..813dc23e6304b16a3078888238a65876d25d2245 --- /dev/null +++ b/server/handles/storage.go @@ -0,0 +1,166 @@
+package handles
+
+import (
+	"context"
+	"strconv"
+
+	"github.com/alist-org/alist/v3/internal/conf"
+	"github.com/alist-org/alist/v3/internal/db"
+	"github.com/alist-org/alist/v3/internal/model"
+	"github.com/alist-org/alist/v3/internal/op"
+	"github.com/alist-org/alist/v3/server/common"
+	"github.com/gin-gonic/gin"
+	log "github.com/sirupsen/logrus"
+)
+
+func ListStorages(c *gin.Context) {
+	var req model.PageReq
+	if err := c.ShouldBind(&req); err != nil {
+		common.ErrorResp(c, err, 400)
+		return
+	}
+	req.Validate()
+	log.Debugf("%+v", req)
+	storages, total, err := db.GetStorages(req.Page, req.PerPage)
+	if err != nil {
+		common.ErrorResp(c, err, 500)
+		return
+	}
+	common.SuccessResp(c, common.PageResp{
+		Content: storages,
+		Total:   total,
+	})
+}
+
+func CreateStorage(c *gin.Context) {
+	var req model.Storage
+	if err := c.ShouldBind(&req); err != nil {
+		common.ErrorResp(c, err, 400)
+		return
+	}
+	if id, err := op.CreateStorage(c, req); err != nil {
+		common.ErrorWithDataResp(c, err, 500, gin.H{
+			"id": id,
+		}, true)
+	} else {
+		common.SuccessResp(c, gin.H{
+			"id": id,
+		})
+	}
+}
+
+func CopyStorage(c *gin.Context) {
+	idStr := c.Query("id")
+	id, err := strconv.Atoi(idStr)
+	if err != nil {
+		common.ErrorResp(c, err, 400)
+		return
+	}
+	if _, err := op.CopyStorageById(c, uint(id)); err != nil {
+		common.ErrorResp(c, err, 500, true)
+		return
+	}
+	common.SuccessResp(c)
+}
+
+func UpdateStorage(c *gin.Context) {
+	var req model.Storage
+	if err := c.ShouldBind(&req); err != nil {
+		common.ErrorResp(c, err, 400)
+		return
+	}
+	if err := op.UpdateStorage(c, req); err != nil {
+		common.ErrorResp(c, err, 500, true)
+	} else {
+		common.SuccessResp(c)
+	}
+}
+
+func DeleteStorage(c *gin.Context) {
+	idStr := c.Query("id")
+	id, err := strconv.Atoi(idStr)
+	if err != nil {
+		common.ErrorResp(c, err, 400)
+		return
+	}
+	if err := op.DeleteStorageById(c, uint(id)); err != nil {
+		common.ErrorResp(c, err, 500, true)
+		return
+	}
+	common.SuccessResp(c)
+}
+
+func DisableStorage(c *gin.Context) {
+	idStr := c.Query("id")
+	id, err := strconv.Atoi(idStr)
+	if err != nil {
+		common.ErrorResp(c, err, 400)
+		return
+	}
+	if err := op.DisableStorage(c, uint(id)); err != nil {
+		common.ErrorResp(c, err, 500,
true) + return + } + common.SuccessResp(c) +} + +func EnableStorage(c *gin.Context) { + idStr := c.Query("id") + id, err := strconv.Atoi(idStr) + if err != nil { + common.ErrorResp(c, err, 400) + return + } + if err := op.EnableStorage(c, uint(id)); err != nil { + common.ErrorResp(c, err, 500, true) + return + } + common.SuccessResp(c) +} + +func GetStorage(c *gin.Context) { + idStr := c.Query("id") + id, err := strconv.Atoi(idStr) + if err != nil { + common.ErrorResp(c, err, 400) + return + } + storage, err := db.GetStorageById(uint(id)) + if err != nil { + common.ErrorResp(c, err, 500, true) + return + } + common.SuccessResp(c, storage) +} + +func LoadAllStorages(c *gin.Context) { + storages, err := db.GetEnabledStorages() + if err != nil { + log.Errorf("failed get enabled storages: %+v", err) + common.ErrorResp(c, err, 500, true) + return + } + conf.StoragesLoaded = false + go func(storages []model.Storage) { + for _, storage := range storages { + storageDriver, err := op.GetStorageByMountPath(storage.MountPath) + if err != nil { + log.Errorf("failed get storage driver: %+v", err) + continue + } + // drop the storage in the driver + if err := storageDriver.Drop(context.Background()); err != nil { + log.Errorf("failed drop storage: %+v", err) + continue + } + if err := op.LoadStorage(context.Background(), storage); err != nil { + log.Errorf("failed get enabled storages: %+v", err) + continue + } + log.Infof("success load storage: [%s], driver: [%s]", + storage.MountPath, storage.Driver) + } + conf.StoragesLoaded = true + }(storages) + common.SuccessResp(c) +} diff --git a/server/handles/task.go b/server/handles/task.go new file mode 100644 index 0000000000000000000000000000000000000000..b9ae657cec49a1034a37466ae573d7199156e7c4 --- /dev/null +++ b/server/handles/task.go @@ -0,0 +1,100 @@ +package handles + +import ( + "math" + + "github.com/alist-org/alist/v3/internal/fs" + "github.com/alist-org/alist/v3/internal/offline_download/tool" + "github.com/alist-org/alist/v3/pkg/tache" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/alist-org/alist/v3/server/common" + "github.com/gin-gonic/gin" +) + +type TaskInfo struct { + ID string `json:"id"` + Name string `json:"name"` + State tache.State `json:"state"` + Status string `json:"status"` + Progress float64 `json:"progress"` + Size int64 `json:"size"` + Error string `json:"error"` +} + +func getTaskInfo[T tache.TaskWithInfo](task T) TaskInfo { + errMsg := "" + if task.GetErr() != nil { + errMsg = task.GetErr().Error() + } + progress := task.GetProgress() + // if progress is NaN, set it to 100 + if math.IsNaN(progress) { + progress = 100 + } + return TaskInfo{ + ID: task.GetID(), + Name: task.GetName(), + State: task.GetState(), + Status: task.GetStatus(), + Size: task.GetSize(), + Progress: progress, + Error: errMsg, + } +} + +func getTaskInfos[T tache.TaskWithInfo](tasks []T) []TaskInfo { + return utils.MustSliceConvert(tasks, getTaskInfo[T]) +} + +func taskRoute[T tache.TaskWithInfo](g *gin.RouterGroup, manager *tache.Manager[T]) { + g.GET("/undone", func(c *gin.Context) { + common.SuccessResp(c, getTaskInfos(manager.GetByState(tache.StatePending, tache.StateRunning, + tache.StateCanceling, tache.StateErrored, tache.StateFailing, tache.StateWaitingRetry, tache.StateBeforeRetry))) + }) + g.GET("/done", func(c *gin.Context) { + common.SuccessResp(c, getTaskInfos(manager.GetByState(tache.StateCanceled, tache.StateFailed, tache.StateSucceeded))) + }) + g.POST("/info", func(c *gin.Context) { + tid := c.Query("tid") + task, ok := 
manager.GetByID(tid) + if !ok { + common.ErrorStrResp(c, "task not found", 404) + return + } + common.SuccessResp(c, getTaskInfo(task)) + }) + g.POST("/cancel", func(c *gin.Context) { + tid := c.Query("tid") + manager.Cancel(tid) + common.SuccessResp(c) + }) + g.POST("/delete", func(c *gin.Context) { + tid := c.Query("tid") + manager.Remove(tid) + common.SuccessResp(c) + }) + g.POST("/retry", func(c *gin.Context) { + tid := c.Query("tid") + manager.Retry(tid) + common.SuccessResp(c) + }) + g.POST("/clear_done", func(c *gin.Context) { + manager.RemoveByState(tache.StateCanceled, tache.StateFailed, tache.StateSucceeded) + common.SuccessResp(c) + }) + g.POST("/clear_succeeded", func(c *gin.Context) { + manager.RemoveByState(tache.StateSucceeded) + common.SuccessResp(c) + }) + g.POST("/retry_failed", func(c *gin.Context) { + manager.RetryAllFailed() + common.SuccessResp(c) + }) +} + +func SetupTaskRoute(g *gin.RouterGroup) { + taskRoute(g.Group("/upload"), fs.UploadTaskManager) + taskRoute(g.Group("/copy"), fs.CopyTaskManager) + taskRoute(g.Group("/offline_download"), tool.DownloadTaskManager) + taskRoute(g.Group("/offline_download_transfer"), tool.TransferTaskManager) +} diff --git a/server/handles/user.go b/server/handles/user.go new file mode 100644 index 0000000000000000000000000000000000000000..4d404a4c6528dc8f427d918efa33459468f5bedc --- /dev/null +++ b/server/handles/user.go @@ -0,0 +1,139 @@ +package handles + +import ( + "strconv" + + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/internal/op" + "github.com/alist-org/alist/v3/server/common" + "github.com/gin-gonic/gin" + log "github.com/sirupsen/logrus" +) + +func ListUsers(c *gin.Context) { + var req model.PageReq + if err := c.ShouldBind(&req); err != nil { + common.ErrorResp(c, err, 400) + return + } + req.Validate() + log.Debugf("%+v", req) + users, total, err := op.GetUsers(req.Page, req.PerPage) + if err != nil { + common.ErrorResp(c, err, 500, true) + return + } + common.SuccessResp(c, common.PageResp{ + Content: users, + Total: total, + }) +} + +func CreateUser(c *gin.Context) { + var req model.User + if err := c.ShouldBind(&req); err != nil { + common.ErrorResp(c, err, 400) + return + } + if req.IsAdmin() || req.IsGuest() { + common.ErrorStrResp(c, "admin or guest user can not be created", 400, true) + return + } + req.SetPassword(req.Password) + req.Password = "" + req.Authn = "[]" + if err := op.CreateUser(&req); err != nil { + common.ErrorResp(c, err, 500, true) + } else { + common.SuccessResp(c) + } +} + +func UpdateUser(c *gin.Context) { + var req model.User + if err := c.ShouldBind(&req); err != nil { + common.ErrorResp(c, err, 400) + return + } + user, err := op.GetUserById(req.ID) + if err != nil { + common.ErrorResp(c, err, 500) + return + } + if user.Role != req.Role { + common.ErrorStrResp(c, "role can not be changed", 400) + return + } + if req.Password == "" { + req.PwdHash = user.PwdHash + req.Salt = user.Salt + } else { + req.SetPassword(req.Password) + req.Password = "" + } + if req.OtpSecret == "" { + req.OtpSecret = user.OtpSecret + } + if req.Disabled && req.IsAdmin() { + common.ErrorStrResp(c, "admin user can not be disabled", 400) + return + } + if err := op.UpdateUser(&req); err != nil { + common.ErrorResp(c, err, 500) + } else { + common.SuccessResp(c) + } +} + +func DeleteUser(c *gin.Context) { + idStr := c.Query("id") + id, err := strconv.Atoi(idStr) + if err != nil { + common.ErrorResp(c, err, 400) + return + } + if err := op.DeleteUserById(uint(id)); err != nil { + 
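// op.DeleteUserById is expected to refuse the built-in admin and guest accounts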
common.ErrorResp(c, err, 500) + return + } + common.SuccessResp(c) +} + +func GetUser(c *gin.Context) { + idStr := c.Query("id") + id, err := strconv.Atoi(idStr) + if err != nil { + common.ErrorResp(c, err, 400) + return + } + user, err := op.GetUserById(uint(id)) + if err != nil { + common.ErrorResp(c, err, 500, true) + return + } + common.SuccessResp(c, user) +} + +func Cancel2FAById(c *gin.Context) { + idStr := c.Query("id") + id, err := strconv.Atoi(idStr) + if err != nil { + common.ErrorResp(c, err, 400) + return + } + if err := op.Cancel2FAById(uint(id)); err != nil { + common.ErrorResp(c, err, 500) + return + } + common.SuccessResp(c) +} + +func DelUserCache(c *gin.Context) { + username := c.Query("username") + err := op.DelUserCache(username) + if err != nil { + common.ErrorResp(c, err, 500) + return + } + common.SuccessResp(c) +} diff --git a/server/handles/webauthn.go b/server/handles/webauthn.go new file mode 100644 index 0000000000000000000000000000000000000000..1bd1884ef119c545f79a4e357b213a189fdcc682 --- /dev/null +++ b/server/handles/webauthn.go @@ -0,0 +1,234 @@ +package handles + +import ( + "encoding/base64" + "encoding/binary" + "encoding/json" + "fmt" + + "github.com/alist-org/alist/v3/internal/authn" + "github.com/alist-org/alist/v3/internal/conf" + "github.com/alist-org/alist/v3/internal/db" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/internal/op" + "github.com/alist-org/alist/v3/internal/setting" + "github.com/alist-org/alist/v3/server/common" + "github.com/gin-gonic/gin" + "github.com/go-webauthn/webauthn/protocol" + "github.com/go-webauthn/webauthn/webauthn" +) + +func BeginAuthnLogin(c *gin.Context) { + enabled := setting.GetBool(conf.WebauthnLoginEnabled) + if !enabled { + common.ErrorStrResp(c, "WebAuthn is not enabled", 403) + return + } + authnInstance, err := authn.NewAuthnInstance(c.Request) + if err != nil { + common.ErrorResp(c, err, 400) + return + } + + var ( + options *protocol.CredentialAssertion + sessionData *webauthn.SessionData + ) + if username := c.Query("username"); username != "" { + var user *model.User + user, err = db.GetUserByName(username) + if err == nil { + options, sessionData, err = authnInstance.BeginLogin(user) + } + } else { // client-side discoverable login + options, sessionData, err = authnInstance.BeginDiscoverableLogin() + } + if err != nil { + common.ErrorResp(c, err, 400) + return + } + + val, err := json.Marshal(sessionData) + if err != nil { + common.ErrorResp(c, err, 400) + return + } + common.SuccessResp(c, gin.H{ + "options": options, + "session": val, + }) +} + +func FinishAuthnLogin(c *gin.Context) { + enabled := setting.GetBool(conf.WebauthnLoginEnabled) + if !enabled { + common.ErrorStrResp(c, "WebAuthn is not enabled", 403) + return + } + authnInstance, err := authn.NewAuthnInstance(c.Request) + if err != nil { + common.ErrorResp(c, err, 400) + return + } + + sessionDataString := c.GetHeader("session") + sessionDataBytes, err := base64.StdEncoding.DecodeString(sessionDataString) + if err != nil { + common.ErrorResp(c, err, 400) + return + } + + var sessionData webauthn.SessionData + if err := json.Unmarshal(sessionDataBytes, &sessionData); err != nil { + common.ErrorResp(c, err, 400) + return + } + + var user *model.User + if username := c.Query("username"); username != "" { + user, err = db.GetUserByName(username) + if err != nil { + common.ErrorResp(c, err, 400) + return + } + _, err = authnInstance.FinishLogin(user, sessionData, c.Request) + } else { // client-side 
discoverable login + _, err = authnInstance.FinishDiscoverableLogin(func(_, userHandle []byte) (webauthn.User, error) { + // first param `rawID` in this callback function is equal to ID in webauthn.Credential, + // but it's unnnecessary to check it. + // userHandle param is equal to (User).WebAuthnID(). + userID := uint(binary.LittleEndian.Uint64(userHandle)) + user, err = db.GetUserById(userID) + if err != nil { + return nil, err + } + + return user, nil + }, sessionData, c.Request) + } + if err != nil { + common.ErrorResp(c, err, 400) + return + } + + token, err := common.GenerateToken(user) + if err != nil { + common.ErrorResp(c, err, 400, true) + return + } + common.SuccessResp(c, gin.H{"token": token}) +} + +func BeginAuthnRegistration(c *gin.Context) { + enabled := setting.GetBool(conf.WebauthnLoginEnabled) + if !enabled { + common.ErrorStrResp(c, "WebAuthn is not enabled", 403) + return + } + user := c.MustGet("user").(*model.User) + + authnInstance, err := authn.NewAuthnInstance(c.Request) + if err != nil { + common.ErrorResp(c, err, 400) + } + + options, sessionData, err := authnInstance.BeginRegistration(user) + + if err != nil { + common.ErrorResp(c, err, 400) + } + + val, err := json.Marshal(sessionData) + if err != nil { + common.ErrorResp(c, err, 400) + } + + common.SuccessResp(c, gin.H{ + "options": options, + "session": val, + }) +} + +func FinishAuthnRegistration(c *gin.Context) { + enabled := setting.GetBool(conf.WebauthnLoginEnabled) + if !enabled { + common.ErrorStrResp(c, "WebAuthn is not enabled", 403) + return + } + user := c.MustGet("user").(*model.User) + sessionDataString := c.GetHeader("Session") + + authnInstance, err := authn.NewAuthnInstance(c.Request) + if err != nil { + common.ErrorResp(c, err, 400) + return + } + + sessionDataBytes, err := base64.StdEncoding.DecodeString(sessionDataString) + if err != nil { + common.ErrorResp(c, err, 400) + return + } + + var sessionData webauthn.SessionData + if err := json.Unmarshal(sessionDataBytes, &sessionData); err != nil { + common.ErrorResp(c, err, 400) + return + } + + credential, err := authnInstance.FinishRegistration(user, sessionData, c.Request) + + if err != nil { + common.ErrorResp(c, err, 400) + return + } + err = db.RegisterAuthn(user, credential) + if err != nil { + common.ErrorResp(c, err, 400) + return + } + err = op.DelUserCache(user.Username) + if err != nil { + common.ErrorResp(c, err, 400) + return + } + common.SuccessResp(c, "Registered Successfully") +} + +func DeleteAuthnLogin(c *gin.Context) { + user := c.MustGet("user").(*model.User) + type DeleteAuthnReq struct { + ID string `json:"id"` + } + var req DeleteAuthnReq + err := c.ShouldBind(&req) + if err != nil { + common.ErrorResp(c, err, 400) + return + } + err = db.RemoveAuthn(user, req.ID) + err = op.DelUserCache(user.Username) + if err != nil { + common.ErrorResp(c, err, 400) + return + } + common.SuccessResp(c, "Deleted Successfully") +} + +func GetAuthnCredentials(c *gin.Context) { + type WebAuthnCredentials struct { + ID []byte `json:"id"` + FingerPrint string `json:"fingerprint"` + } + user := c.MustGet("user").(*model.User) + credentials := user.WebAuthnCredentials() + res := make([]WebAuthnCredentials, 0, len(credentials)) + for _, v := range credentials { + credential := WebAuthnCredentials{ + ID: v.ID, + FingerPrint: fmt.Sprintf("% X", v.Authenticator.AAGUID), + } + res = append(res, credential) + } + common.SuccessResp(c, res) +} diff --git a/server/middlewares/auth.go b/server/middlewares/auth.go new file mode 100644 index 
0000000000000000000000000000000000000000..14f186be8bfc49340535c7901449ad1e06b78cd0 --- /dev/null +++ b/server/middlewares/auth.go @@ -0,0 +1,138 @@ +package middlewares + +import ( + "crypto/subtle" + + "github.com/alist-org/alist/v3/internal/conf" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/internal/op" + "github.com/alist-org/alist/v3/internal/setting" + "github.com/alist-org/alist/v3/server/common" + "github.com/gin-gonic/gin" + log "github.com/sirupsen/logrus" +) + +// Auth is a middleware that checks if the user is logged in. +// if token is empty, set user to guest +func Auth(c *gin.Context) { + token := c.GetHeader("Authorization") + if subtle.ConstantTimeCompare([]byte(token), []byte(setting.GetStr(conf.Token))) == 1 { + admin, err := op.GetAdmin() + if err != nil { + common.ErrorResp(c, err, 500) + c.Abort() + return + } + c.Set("user", admin) + log.Debugf("use admin token: %+v", admin) + c.Next() + return + } + if token == "" { + guest, err := op.GetGuest() + if err != nil { + common.ErrorResp(c, err, 500) + c.Abort() + return + } + if guest.Disabled { + common.ErrorStrResp(c, "Guest user is disabled, login please", 401) + c.Abort() + return + } + c.Set("user", guest) + log.Debugf("use empty token: %+v", guest) + c.Next() + return + } + userClaims, err := common.ParseToken(token) + if err != nil { + common.ErrorResp(c, err, 401) + c.Abort() + return + } + user, err := op.GetUserByName(userClaims.Username) + if err != nil { + common.ErrorResp(c, err, 401) + c.Abort() + return + } + // validate password timestamp + if userClaims.PwdTS != user.PwdTS { + common.ErrorStrResp(c, "Password has been changed, login please", 401) + c.Abort() + return + } + if user.Disabled { + common.ErrorStrResp(c, "Current user is disabled, replace please", 401) + c.Abort() + return + } + c.Set("user", user) + log.Debugf("use login token: %+v", user) + c.Next() +} + +func Authn(c *gin.Context) { + token := c.GetHeader("Authorization") + if subtle.ConstantTimeCompare([]byte(token), []byte(setting.GetStr(conf.Token))) == 1 { + admin, err := op.GetAdmin() + if err != nil { + common.ErrorResp(c, err, 500) + c.Abort() + return + } + c.Set("user", admin) + log.Debugf("use admin token: %+v", admin) + c.Next() + return + } + if token == "" { + guest, err := op.GetGuest() + if err != nil { + common.ErrorResp(c, err, 500) + c.Abort() + return + } + c.Set("user", guest) + log.Debugf("use empty token: %+v", guest) + c.Next() + return + } + userClaims, err := common.ParseToken(token) + if err != nil { + common.ErrorResp(c, err, 401) + c.Abort() + return + } + user, err := op.GetUserByName(userClaims.Username) + if err != nil { + common.ErrorResp(c, err, 401) + c.Abort() + return + } + // validate password timestamp + if userClaims.PwdTS != user.PwdTS { + common.ErrorStrResp(c, "Password has been changed, login please", 401) + c.Abort() + return + } + if user.Disabled { + common.ErrorStrResp(c, "Current user is disabled, replace please", 401) + c.Abort() + return + } + c.Set("user", user) + log.Debugf("use login token: %+v", user) + c.Next() +} + +func AuthAdmin(c *gin.Context) { + user := c.MustGet("user").(*model.User) + if !user.IsAdmin() { + common.ErrorStrResp(c, "You are not an admin", 403) + c.Abort() + } else { + c.Next() + } +} diff --git a/server/middlewares/check.go b/server/middlewares/check.go new file mode 100644 index 0000000000000000000000000000000000000000..aa3878e5ef3eb7738ce43f300489051e14c9f33a --- /dev/null +++ b/server/middlewares/check.go @@ -0,0 +1,30 @@ 
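+// check.go gates most requests until the storage drivers have finished
+// loading; only the root page, favicon, and static asset prefixes pass early.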
+package middlewares + +import ( + "strings" + + "github.com/alist-org/alist/v3/internal/conf" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/alist-org/alist/v3/server/common" + "github.com/gin-gonic/gin" +) + +func StoragesLoaded(c *gin.Context) { + if conf.StoragesLoaded { + c.Next() + } else { + if utils.SliceContains([]string{"", "/", "/favicon.ico"}, c.Request.URL.Path) { + c.Next() + return + } + paths := []string{"/assets", "/images", "/streamer", "/static"} + for _, path := range paths { + if strings.HasPrefix(c.Request.URL.Path, path) { + c.Next() + return + } + } + common.ErrorStrResp(c, "Loading storage, please wait", 500) + c.Abort() + } +} diff --git a/server/middlewares/down.go b/server/middlewares/down.go new file mode 100644 index 0000000000000000000000000000000000000000..05e9dc856d8519b3978d44e9558193760478c4e2 --- /dev/null +++ b/server/middlewares/down.go @@ -0,0 +1,63 @@ +package middlewares + +import ( + "strings" + + "github.com/alist-org/alist/v3/internal/conf" + "github.com/alist-org/alist/v3/internal/setting" + + "github.com/alist-org/alist/v3/internal/errs" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/internal/op" + "github.com/alist-org/alist/v3/internal/sign" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/alist-org/alist/v3/server/common" + "github.com/gin-gonic/gin" + "github.com/pkg/errors" +) + +func Down(c *gin.Context) { + rawPath := parsePath(c.Param("path")) + c.Set("path", rawPath) + meta, err := op.GetNearestMeta(rawPath) + if err != nil { + if !errors.Is(errors.Cause(err), errs.MetaNotFound) { + common.ErrorResp(c, err, 500, true) + return + } + } + c.Set("meta", meta) + // verify sign + if needSign(meta, rawPath) { + s := c.Query("sign") + err = sign.Verify(rawPath, strings.TrimSuffix(s, "/")) + if err != nil { + common.ErrorResp(c, err, 401) + c.Abort() + return + } + } + c.Next() +} + +// TODO: implement +// path maybe contains # ? etc. 
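+// Until then the path is only cleaned and normalized by utils.FixAndCleanPath.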
+func parsePath(path string) string { + return utils.FixAndCleanPath(path) +} + +func needSign(meta *model.Meta, path string) bool { + if setting.GetBool(conf.SignAll) { + return true + } + if common.IsStorageSignEnabled(path) { + return true + } + if meta == nil || meta.Password == "" { + return false + } + if !meta.PSub && path != meta.Path { + return false + } + return true +} diff --git a/server/middlewares/fsup.go b/server/middlewares/fsup.go new file mode 100644 index 0000000000000000000000000000000000000000..2aa7fca6d036bf769b80d112b9917ed50e9501db --- /dev/null +++ b/server/middlewares/fsup.go @@ -0,0 +1,44 @@ +package middlewares + +import ( + "net/url" + stdpath "path" + + "github.com/alist-org/alist/v3/internal/errs" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/internal/op" + "github.com/alist-org/alist/v3/server/common" + "github.com/gin-gonic/gin" + "github.com/pkg/errors" +) + +func FsUp(c *gin.Context) { + path := c.GetHeader("File-Path") + password := c.GetHeader("Password") + path, err := url.PathUnescape(path) + if err != nil { + common.ErrorResp(c, err, 400) + c.Abort() + return + } + user := c.MustGet("user").(*model.User) + path, err = user.JoinPath(path) + if err != nil { + common.ErrorResp(c, err, 403) + return + } + meta, err := op.GetNearestMeta(stdpath.Dir(path)) + if err != nil { + if !errors.Is(errors.Cause(err), errs.MetaNotFound) { + common.ErrorResp(c, err, 500, true) + c.Abort() + return + } + } + if !(common.CanAccess(user, meta, path, password) && (user.CanWrite() || common.CanWrite(meta, stdpath.Dir(path)))) { + common.ErrorResp(c, errs.PermissionDenied, 403) + c.Abort() + return + } + c.Next() +} diff --git a/server/middlewares/https.go b/server/middlewares/https.go new file mode 100644 index 0000000000000000000000000000000000000000..8c71eb71ff1fc6e5b0de5ddf99ed87fd27490962 --- /dev/null +++ b/server/middlewares/https.go @@ -0,0 +1,21 @@ +package middlewares + +import ( + "fmt" + "strings" + + "github.com/alist-org/alist/v3/internal/conf" + "github.com/gin-gonic/gin" +) + +func ForceHttps(c *gin.Context) { + if c.Request.TLS == nil { + host := c.Request.Host + // change port to https port + host = strings.Replace(host, fmt.Sprintf(":%d", conf.Conf.Scheme.HttpPort), fmt.Sprintf(":%d", conf.Conf.Scheme.HttpsPort), 1) + c.Redirect(302, "https://"+host+c.Request.RequestURI) + c.Abort() + return + } + c.Next() +} diff --git a/server/middlewares/limit.go b/server/middlewares/limit.go new file mode 100644 index 0000000000000000000000000000000000000000..44c079b37e088d51e230bec8a3407f34df7382f3 --- /dev/null +++ b/server/middlewares/limit.go @@ -0,0 +1,16 @@ +package middlewares + +import ( + "github.com/gin-gonic/gin" +) + +func MaxAllowed(n int) gin.HandlerFunc { + sem := make(chan struct{}, n) + acquire := func() { sem <- struct{}{} } + release := func() { <-sem } + return func(c *gin.Context) { + acquire() + defer release() + c.Next() + } +} diff --git a/server/middlewares/search.go b/server/middlewares/search.go new file mode 100644 index 0000000000000000000000000000000000000000..5d84aadece898dd0a8c2ffc89b9c4e0abe0f2d35 --- /dev/null +++ b/server/middlewares/search.go @@ -0,0 +1,19 @@ +package middlewares + +import ( + "github.com/alist-org/alist/v3/internal/conf" + "github.com/alist-org/alist/v3/internal/errs" + "github.com/alist-org/alist/v3/internal/setting" + "github.com/alist-org/alist/v3/server/common" + "github.com/gin-gonic/gin" +) + +func SearchIndex(c *gin.Context) { + mode := setting.GetStr(conf.SearchIndex) + 
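// "none" means no search index is configured, so reject search requests outright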
diff --git a/server/router.go b/server/router.go
new file mode 100644
index 0000000000000000000000000000000000000000..b579fc93ca5f415554c24af2792ec2e5c98e1fe7
--- /dev/null
+++ b/server/router.go
@@ -0,0 +1,205 @@
+package server
+
+import (
+	"math/rand"
+	"strings"
+	"time"
+
+	"github.com/alist-org/alist/v3/cmd/flags"
+	"github.com/alist-org/alist/v3/internal/conf"
+	"github.com/alist-org/alist/v3/internal/message"
+	"github.com/alist-org/alist/v3/pkg/utils"
+	"github.com/alist-org/alist/v3/server/common"
+	"github.com/alist-org/alist/v3/server/handles"
+	"github.com/alist-org/alist/v3/server/middlewares"
+	"github.com/alist-org/alist/v3/server/static"
+	"github.com/gin-contrib/cors"
+	"github.com/gin-gonic/gin"
+)
+
+// generateRandomString returns a random string of digits and letters with the
+// given length
+func generateRandomString(length int) string {
+	// character set containing digits and upper- and lower-case letters
+	charSet := "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
+
+	// rand.Seed is deprecated since Go 1.20; use a locally seeded generator
+	rng := rand.New(rand.NewSource(time.Now().UnixNano()))
+	var sb strings.Builder
+	for i := 0; i < length; i++ {
+		sb.WriteByte(charSet[rng.Intn(len(charSet))])
+	}
+
+	return sb.String()
+}
+
+func Init(e *gin.Engine) {
+	randomStr := generateRandomString(10)
+	if !utils.SliceContains([]string{"", "/"}, conf.URL.Path) {
+		e.GET("/", func(c *gin.Context) {
+			c.Redirect(302, conf.URL.Path)
+		})
+	}
+	Cors(e)
+	g := e.Group(conf.URL.Path)
+	if conf.Conf.Scheme.HttpPort != -1 && conf.Conf.Scheme.HttpsPort != -1 && conf.Conf.Scheme.ForceHttps {
+		e.Use(middlewares.ForceHttps)
+	}
+	g.Any("/ping", func(c *gin.Context) {
+		c.String(200, "pong")
+	})
+	g.GET("/favicon.ico", handles.Favicon)
+	g.GET("/robots.txt", handles.Robots)
+	g.GET("/i/:link_name", handles.Plist)
+	common.SecretKey = []byte(conf.Conf.JwtSecret)
+	g.Use(middlewares.StoragesLoaded)
+	if conf.Conf.MaxConnections > 0 {
+		g.Use(middlewares.MaxAllowed(conf.Conf.MaxConnections))
+	}
+	WebDav(g.Group("/dav"))
+	S3(g.Group("/s3"))
+
+	g.GET("/d/*path", middlewares.Down, handles.Down)
+	g.GET("/p/*path", middlewares.Down, handles.Proxy)
+	g.HEAD("/d/*path", middlewares.Down, handles.Down)
+	g.HEAD("/p/*path", middlewares.Down, handles.Proxy)
+
+	api := g.Group("/api")
+	auth := api.Group("", middlewares.Auth)
+	webauthn := api.Group("/authn", middlewares.Authn)
+
+	api.POST("/auth/login", handles.Login)
+	api.POST("/auth/login/hash", handles.LoginHash)
+	api.POST("/auth/login/ldap", handles.LoginLdap)
+	auth.GET("/me", handles.CurrentUser)
+	auth.POST("/me/update", handles.UpdateCurrent)
+	auth.POST("/auth/2fa/generate", handles.Generate2FA)
+	auth.POST("/auth/2fa/verify", handles.Verify2FA)
+
+	// auth
+	api.GET("/auth/sso", handles.SSOLoginRedirect)
+	api.GET("/auth/sso_callback", handles.SSOLoginCallback)
+	api.GET("/auth/get_sso_id", handles.SSOLoginCallback)
+	api.GET("/auth/sso_get_token", handles.SSOLoginCallback)
+
+	// random string generated at startup, used to verify instance uniqueness
+	api.GET("/instanceid", func(c *gin.Context) {
+		c.String(200, randomStr)
+	})
+
+	// webauthn
+	webauthn.GET("/webauthn_begin_registration", handles.BeginAuthnRegistration)
+	webauthn.POST("/webauthn_finish_registration", handles.FinishAuthnRegistration)
+	webauthn.GET("/webauthn_begin_login", handles.BeginAuthnLogin)
+	webauthn.POST("/webauthn_finish_login", handles.FinishAuthnLogin)
+	webauthn.POST("/delete_authn", handles.DeleteAuthnLogin)
+	webauthn.GET("/getcredentials", handles.GetAuthnCredentials)
+
+	// no auth needed
+	public := api.Group("/public")
+	public.Any("/settings", handles.PublicSettings)
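+	// routes on the public group are served without authentication; everything
+	// mounted on the auth group goes through middlewares.Auth first
+	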
public.Any("/offline_download_tools", handles.OfflineDownloadTools) + + _fs(auth.Group("/fs")) + admin(auth.Group("/admin", middlewares.AuthAdmin)) + if flags.Debug || flags.Dev { + debug(g.Group("/debug")) + } + static.Static(g, func(handlers ...gin.HandlerFunc) { + e.NoRoute(handlers...) + }) +} + +func admin(g *gin.RouterGroup) { + meta := g.Group("/meta") + meta.GET("/list", handles.ListMetas) + meta.GET("/get", handles.GetMeta) + meta.POST("/create", handles.CreateMeta) + meta.POST("/update", handles.UpdateMeta) + meta.POST("/delete", handles.DeleteMeta) + + user := g.Group("/user") + user.GET("/list", handles.ListUsers) + user.GET("/get", handles.GetUser) + user.POST("/create", handles.CreateUser) + user.POST("/update", handles.UpdateUser) + user.POST("/cancel_2fa", handles.Cancel2FAById) + user.POST("/delete", handles.DeleteUser) + user.POST("/del_cache", handles.DelUserCache) + + storage := g.Group("/storage") + storage.GET("/list", handles.ListStorages) + storage.GET("/get", handles.GetStorage) + storage.POST("/create", handles.CreateStorage) + storage.POST("/copy", handles.CopyStorage) + storage.POST("/update", handles.UpdateStorage) + storage.POST("/delete", handles.DeleteStorage) + storage.POST("/enable", handles.EnableStorage) + storage.POST("/disable", handles.DisableStorage) + storage.POST("/load_all", handles.LoadAllStorages) + + driver := g.Group("/driver") + driver.GET("/list", handles.ListDriverInfo) + driver.GET("/names", handles.ListDriverNames) + driver.GET("/info", handles.GetDriverInfo) + + setting := g.Group("/setting") + setting.GET("/get", handles.GetSetting) + setting.GET("/list", handles.ListSettings) + setting.POST("/save", handles.SaveSettings) + setting.POST("/delete", handles.DeleteSetting) + setting.POST("/reset_token", handles.ResetToken) + setting.POST("/set_aria2", handles.SetAria2) + setting.POST("/set_qbit", handles.SetQbittorrent) + + task := g.Group("/task") + handles.SetupTaskRoute(task) + + ms := g.Group("/message") + ms.POST("/get", message.HttpInstance.GetHandle) + ms.POST("/send", message.HttpInstance.SendHandle) + + index := g.Group("/index") + index.POST("/build", middlewares.SearchIndex, handles.BuildIndex) + index.POST("/update", middlewares.SearchIndex, handles.UpdateIndex) + index.POST("/stop", middlewares.SearchIndex, handles.StopIndex) + index.POST("/clear", middlewares.SearchIndex, handles.ClearIndex) + index.GET("/progress", middlewares.SearchIndex, handles.GetProgress) +} + +func _fs(g *gin.RouterGroup) { + g.Any("/list", handles.FsList) + g.Any("/search", middlewares.SearchIndex, handles.Search) + g.Any("/get", handles.FsGet) + g.Any("/other", handles.FsOther) + g.Any("/dirs", handles.FsDirs) + g.POST("/mkdir", handles.FsMkdir) + g.POST("/rename", handles.FsRename) + g.POST("/batch_rename", handles.FsBatchRename) + g.POST("/regex_rename", handles.FsRegexRename) + g.POST("/move", handles.FsMove) + g.POST("/recursive_move", handles.FsRecursiveMove) + g.POST("/copy", handles.FsCopy) + g.POST("/copy_item", handles.FsCopyItem) + g.POST("/remove", handles.FsRemove) + g.POST("/remove_empty_directory", handles.FsRemoveEmptyDirectory) + g.PUT("/put", middlewares.FsUp, handles.FsStream) + g.PUT("/form", middlewares.FsUp, handles.FsForm) + g.POST("/link", middlewares.AuthAdmin, handles.Link) + //g.POST("/add_aria2", handles.AddOfflineDownload) + //g.POST("/add_qbit", handles.AddQbittorrent) + g.POST("/add_offline_download", handles.AddOfflineDownload) +} + +func Cors(r *gin.Engine) { + config := cors.DefaultConfig() + //config.AllowAllOrigins 
= true + config.AllowOrigins = conf.Conf.Cors.AllowOrigins + config.AllowHeaders = conf.Conf.Cors.AllowHeaders + config.AllowMethods = conf.Conf.Cors.AllowMethods + r.Use(cors.New(config)) +} + +func InitS3(e *gin.Engine) { + Cors(e) + S3Server(e.Group("/")) +} diff --git a/server/s3.go b/server/s3.go new file mode 100644 index 0000000000000000000000000000000000000000..21b95527ded0ee14f9522e81ee7da0d9432f5da5 --- /dev/null +++ b/server/s3.go @@ -0,0 +1,39 @@ +package server + +import ( + "context" + "path" + "strings" + + "github.com/alist-org/alist/v3/internal/conf" + "github.com/alist-org/alist/v3/server/common" + "github.com/alist-org/alist/v3/server/s3" + "github.com/gin-gonic/gin" +) + +func S3(g *gin.RouterGroup) { + if !conf.Conf.S3.Enable { + g.Any("/*path", func(c *gin.Context) { + common.ErrorStrResp(c, "S3 server is not enabled", 403) + }) + return + } + if conf.Conf.S3.Port != -1 { + g.Any("/*path", func(c *gin.Context) { + common.ErrorStrResp(c, "S3 server bound to single port", 403) + }) + return + } + h, _ := s3.NewServer(context.Background()) + + g.Any("/*path", func(c *gin.Context) { + adjustedPath := strings.TrimPrefix(c.Request.URL.Path, path.Join(conf.URL.Path, "/s3")) + c.Request.URL.Path = adjustedPath + gin.WrapH(h)(c) + }) +} + +func S3Server(g *gin.RouterGroup) { + h, _ := s3.NewServer(context.Background()) + g.Any("/*path", gin.WrapH(h)) +} diff --git a/server/s3/backend.go b/server/s3/backend.go new file mode 100644 index 0000000000000000000000000000000000000000..75c6b28b1ef2dd6c1d767de020ea7437cf87489e --- /dev/null +++ b/server/s3/backend.go @@ -0,0 +1,432 @@ +// Credits: https://pkg.go.dev/github.com/rclone/rclone@v1.65.2/cmd/serve/s3 +// Package s3 implements a fake s3 server for alist +package s3 + +import ( + "context" + "encoding/hex" + "fmt" + "io" + "path" + "strings" + "sync" + "time" + + "github.com/Mikubill/gofakes3" + "github.com/alist-org/alist/v3/internal/errs" + "github.com/alist-org/alist/v3/internal/fs" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/internal/op" + "github.com/alist-org/alist/v3/internal/stream" + "github.com/alist-org/alist/v3/pkg/http_range" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/ncw/swift/v2" +) + +var ( + emptyPrefix = &gofakes3.Prefix{} + timeFormat = "Mon, 2 Jan 2006 15:04:05.999999999 GMT" +) + +// s3Backend implements the gofacess3.Backend interface to make an S3 +// backend for gofakes3 +type s3Backend struct { + meta *sync.Map +} + +// newBackend creates a new SimpleBucketBackend. +func newBackend() gofakes3.Backend { + return &s3Backend{ + meta: new(sync.Map), + } +} + +// ListBuckets always returns the default bucket. +func (b *s3Backend) ListBuckets() ([]gofakes3.BucketInfo, error) { + buckets, err := getAndParseBuckets() + if err != nil { + return nil, err + } + var response []gofakes3.BucketInfo + ctx := context.Background() + for _, b := range buckets { + node, _ := fs.Get(ctx, b.Path, &fs.GetArgs{}) + response = append(response, gofakes3.BucketInfo{ + // Name: gofakes3.URLEncode(b.Name), + Name: b.Name, + CreationDate: gofakes3.NewContentTime(node.ModTime()), + }) + } + return response, nil +} + +// ListBucket lists the objects in the given bucket. 
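+//
+// Buckets themselves come from the S3Buckets setting: a JSON array mapping
+// bucket names to alist paths, for example (values are illustrative; field
+// names follow the Bucket struct in server/s3/utils.go):
+//
+//	[
+//	  {"name": "data", "path": "/local/data"},
+//	  {"name": "media", "path": "/aliyun/media"}
+//	]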
+func (b *s3Backend) ListBucket(bucketName string, prefix *gofakes3.Prefix, page gofakes3.ListBucketPage) (*gofakes3.ObjectList, error) {
+	bucket, err := getBucketByName(bucketName)
+	if err != nil {
+		return nil, err
+	}
+	bucketPath := bucket.Path
+
+	if prefix == nil {
+		prefix = emptyPrefix
+	}
+
+	// workaround
+	if strings.TrimSpace(prefix.Prefix) == "" {
+		prefix.HasPrefix = false
+	}
+	if strings.TrimSpace(prefix.Delimiter) == "" {
+		prefix.HasDelimiter = false
+	}
+
+	response := gofakes3.NewObjectList()
+	path, remaining := prefixParser(prefix)
+
+	err = b.entryListR(bucketPath, path, remaining, prefix.HasDelimiter, response)
+	if err == gofakes3.ErrNoSuchKey {
+		// AWS just returns an empty list
+		response = gofakes3.NewObjectList()
+	} else if err != nil {
+		return nil, err
+	}
+
+	return b.pager(response, page)
+}
+
+// HeadObject returns the fileinfo for the given object name.
+//
+// Note that the metadata is not supported yet.
+func (b *s3Backend) HeadObject(bucketName, objectName string) (*gofakes3.Object, error) {
+	ctx := context.Background()
+	bucket, err := getBucketByName(bucketName)
+	if err != nil {
+		return nil, err
+	}
+	bucketPath := bucket.Path
+
+	fp := path.Join(bucketPath, objectName)
+	fmeta, _ := op.GetNearestMeta(fp)
+	node, err := fs.Get(context.WithValue(ctx, "meta", fmeta), fp, &fs.GetArgs{})
+	if err != nil {
+		return nil, gofakes3.KeyNotFound(objectName)
+	}
+
+	if node.IsDir() {
+		return nil, gofakes3.KeyNotFound(objectName)
+	}
+
+	size := node.GetSize()
+	// hash := getFileHashByte(fobj)
+
+	meta := map[string]string{
+		"Last-Modified": node.ModTime().Format(timeFormat),
+		"Content-Type":  utils.GetMimeType(fp),
+	}
+
+	if val, ok := b.meta.Load(fp); ok {
+		metaMap := val.(map[string]string)
+		for k, v := range metaMap {
+			meta[k] = v
+		}
+	}
+
+	return &gofakes3.Object{
+		Name: objectName,
+		// Hash:     hash,
+		Metadata: meta,
+		Size:     size,
+		Contents: noOpReadCloser{},
+	}, nil
+}
+
+// GetObject fetches the object from the filesystem.
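+//
+// Ranged reads are served through whichever capability the driver exposes:
+// a RangeReadCloser, a plain URL that is re-requested with a Range header
+// (via stream.GetRangeReadCloserFromLink), or a seekable local MFile; see
+// the fallbacks in the function body below.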
+func (b *s3Backend) GetObject(bucketName, objectName string, rangeRequest *gofakes3.ObjectRangeRequest) (obj *gofakes3.Object, err error) { + ctx := context.Background() + bucket, err := getBucketByName(bucketName) + if err != nil { + return nil, err + } + bucketPath := bucket.Path + + fp := path.Join(bucketPath, objectName) + fmeta, _ := op.GetNearestMeta(fp) + node, err := fs.Get(context.WithValue(ctx, "meta", fmeta), fp, &fs.GetArgs{}) + if err != nil { + return nil, gofakes3.KeyNotFound(objectName) + } + + if node.IsDir() { + return nil, gofakes3.KeyNotFound(objectName) + } + + link, file, err := fs.Link(ctx, fp, model.LinkArgs{}) + if err != nil { + return nil, err + } + + size := file.GetSize() + rnge, err := rangeRequest.Range(size) + if err != nil { + return nil, err + } + + if link.RangeReadCloser == nil && link.MFile == nil && len(link.URL) == 0 { + return nil, fmt.Errorf("the remote storage driver need to be enhanced to support s3") + } + remoteFileSize := file.GetSize() + remoteClosers := utils.EmptyClosers() + rangeReaderFunc := func(ctx context.Context, start, length int64) (io.ReadCloser, error) { + if length >= 0 && start+length >= remoteFileSize { + length = -1 + } + rrc := link.RangeReadCloser + if len(link.URL) > 0 { + + rangedRemoteLink := &model.Link{ + URL: link.URL, + Header: link.Header, + } + var converted, err = stream.GetRangeReadCloserFromLink(remoteFileSize, rangedRemoteLink) + if err != nil { + return nil, err + } + rrc = converted + } + if rrc != nil { + remoteReader, err := rrc.RangeRead(ctx, http_range.Range{Start: start, Length: length}) + remoteClosers.AddClosers(rrc.GetClosers()) + if err != nil { + return nil, err + } + return remoteReader, nil + } + if link.MFile != nil { + _, err := link.MFile.Seek(start, io.SeekStart) + if err != nil { + return nil, err + } + //remoteClosers.Add(remoteLink.MFile) + //keep reuse same MFile and close at last. + remoteClosers.Add(link.MFile) + return io.NopCloser(link.MFile), nil + } + return nil, errs.NotSupport + } + + var rdr io.ReadCloser + if rnge != nil { + rdr, err = rangeReaderFunc(ctx, rnge.Start, rnge.Length) + if err != nil { + return nil, err + } + } else { + rdr, err = rangeReaderFunc(ctx, 0, -1) + if err != nil { + return nil, err + } + } + + meta := map[string]string{ + "Last-Modified": node.ModTime().Format(timeFormat), + "Content-Type": utils.GetMimeType(fp), + } + + if val, ok := b.meta.Load(fp); ok { + metaMap := val.(map[string]string) + for k, v := range metaMap { + meta[k] = v + } + } + + return &gofakes3.Object{ + // Name: gofakes3.URLEncode(objectName), + Name: objectName, + // Hash: "", + Metadata: meta, + Size: size, + Range: rnge, + Contents: rdr, + }, nil +} + +// TouchObject creates or updates meta on specified object. +func (b *s3Backend) TouchObject(fp string, meta map[string]string) (result gofakes3.PutObjectResult, err error) { + //TODO: implement + return result, gofakes3.ErrNotImplemented +} + +// PutObject creates or overwrites the object with the given name. 
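+//
+// If the client supplies a modification-time hint (X-Amz-Meta-Mtime or mtime,
+// a Unix timestamp as a float string such as "1700000000.123", value shown is
+// illustrative), it is parsed with swift.FloatStringToTime and used as the
+// object's Modified time.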
+func (b *s3Backend) PutObject(
+	bucketName, objectName string,
+	meta map[string]string,
+	input io.Reader, size int64,
+) (result gofakes3.PutObjectResult, err error) {
+	ctx := context.Background()
+	bucket, err := getBucketByName(bucketName)
+	if err != nil {
+		return result, err
+	}
+	bucketPath := bucket.Path
+
+	fp := path.Join(bucketPath, objectName)
+	reqPath := path.Dir(fp)
+	fmeta, _ := op.GetNearestMeta(fp)
+	_, err = fs.Get(context.WithValue(ctx, "meta", fmeta), reqPath, &fs.GetArgs{})
+	if err != nil {
+		return result, gofakes3.KeyNotFound(objectName)
+	}
+
+	var ti time.Time
+
+	if val, ok := meta["X-Amz-Meta-Mtime"]; ok {
+		ti, _ = swift.FloatStringToTime(val)
+	}
+
+	if val, ok := meta["mtime"]; ok {
+		ti, _ = swift.FloatStringToTime(val)
+	}
+
+	obj := model.Object{
+		Name:     path.Base(fp),
+		Size:     size,
+		Modified: ti,
+		Ctime:    time.Now(),
+	}
+	// named fileStream so the stream package itself is not shadowed
+	fileStream := &stream.FileStream{
+		Obj:      &obj,
+		Reader:   input,
+		Mimetype: meta["Content-Type"],
+	}
+
+	err = fs.PutDirectly(ctx, reqPath, fileStream)
+	if err != nil {
+		return result, err
+	}
+
+	if err := fileStream.Close(); err != nil {
+		// remove file when close error occurred (FsPutErr)
+		_ = fs.Remove(ctx, fp)
+		return result, err
+	}
+
+	b.meta.Store(fp, meta)
+
+	return result, nil
+}
+
+// DeleteMulti deletes multiple objects in a single request.
+func (b *s3Backend) DeleteMulti(bucketName string, objects ...string) (result gofakes3.MultiDeleteResult, rerr error) {
+	for _, object := range objects {
+		if err := b.deleteObject(bucketName, object); err != nil {
+			utils.Log.Errorf("serve s3: delete object failed: %v", err)
+			result.Error = append(result.Error, gofakes3.ErrorResult{
+				Code:    gofakes3.ErrInternal,
+				Message: gofakes3.ErrInternal.Message(),
+				Key:     object,
+			})
+		} else {
+			result.Deleted = append(result.Deleted, gofakes3.ObjectID{
+				Key: object,
+			})
+		}
+	}
+
+	return result, nil
+}
+
+// DeleteObject deletes the object with the given name.
+func (b *s3Backend) DeleteObject(bucketName, objectName string) (result gofakes3.ObjectDeleteResult, rerr error) {
+	return result, b.deleteObject(bucketName, objectName)
+}
+
+// deleteObject deletes the object from the filesystem.
+func (b *s3Backend) deleteObject(bucketName, objectName string) error {
+	ctx := context.Background()
+	bucket, err := getBucketByName(bucketName)
+	if err != nil {
+		return err
+	}
+	bucketPath := bucket.Path
+
+	fp := path.Join(bucketPath, objectName)
+	fmeta, _ := op.GetNearestMeta(fp)
+	// S3 does not report an error when attempting to delete a key that does not exist, so
+	// we need to skip IsNotExist errors.
+	if _, err := fs.Get(context.WithValue(ctx, "meta", fmeta), fp, &fs.GetArgs{}); err != nil && !errs.IsObjectNotFound(err) {
+		return err
+	}
+
+	fs.Remove(ctx, fp)
+	return nil
+}
+
+// CreateBucket creates a new bucket.
+func (b *s3Backend) CreateBucket(name string) error {
+	return gofakes3.ErrNotImplemented
+}
+
+// DeleteBucket deletes the bucket with the given name.
+func (b *s3Backend) DeleteBucket(name string) error {
+	return gofakes3.ErrNotImplemented
+}
+
+// BucketExists checks if the bucket exists.
+func (b *s3Backend) BucketExists(name string) (exists bool, err error) {
+	buckets, err := getAndParseBuckets()
+	if err != nil {
+		return false, err
+	}
+	for _, b := range buckets {
+		if b.Name == name {
+			return true, nil
+		}
+	}
+	return false, nil
+}
+
+// CopyObject copies the specified object from srcKey to dstKey.
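+//
+// Source metadata is carried over unless the request overrides it; X-Amz-Acl
+// is deliberately dropped, and mtime falls back to the source object's
+// ModTime, as the copy loop below shows.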
+func (b *s3Backend) CopyObject(srcBucket, srcKey, dstBucket, dstKey string, meta map[string]string) (result gofakes3.CopyObjectResult, err error) {
+	if srcBucket == dstBucket && srcKey == dstKey {
+		//TODO: update meta
+		return result, nil
+	}
+
+	ctx := context.Background()
+	srcB, err := getBucketByName(srcBucket)
+	if err != nil {
+		return result, err
+	}
+	srcBucketPath := srcB.Path
+
+	srcFp := path.Join(srcBucketPath, srcKey)
+	fmeta, _ := op.GetNearestMeta(srcFp)
+	srcNode, err := fs.Get(context.WithValue(ctx, "meta", fmeta), srcFp, &fs.GetArgs{})
+	if err != nil {
+		return result, err
+	}
+
+	c, err := b.GetObject(srcBucket, srcKey, nil)
+	if err != nil {
+		return
+	}
+	defer func() {
+		_ = c.Contents.Close()
+	}()
+
+	for k, v := range c.Metadata {
+		if _, found := meta[k]; !found && k != "X-Amz-Acl" {
+			meta[k] = v
+		}
+	}
+	if _, ok := meta["mtime"]; !ok {
+		meta["mtime"] = swift.TimeToFloatString(srcNode.ModTime())
+	}
+
+	_, err = b.PutObject(dstBucket, dstKey, meta, c.Contents, c.Size)
+	if err != nil {
+		return
+	}
+
+	return gofakes3.CopyObjectResult{
+		ETag:         `"` + hex.EncodeToString(c.Hash) + `"`,
+		LastModified: gofakes3.NewContentTime(srcNode.ModTime()),
+	}, nil
+}
diff --git a/server/s3/ioutils.go b/server/s3/ioutils.go
new file mode 100644
index 0000000000000000000000000000000000000000..6b49cacc7087c281dd4bdf2360aa3a8f73ba23eb
--- /dev/null
+++ b/server/s3/ioutils.go
@@ -0,0 +1,36 @@
+// Credits: https://pkg.go.dev/github.com/rclone/rclone@v1.65.2/cmd/serve/s3
+// Package s3 implements a fake s3 server for alist
+package s3
+
+import "io"
+
+type noOpReadCloser struct{}
+
+type readerWithCloser struct {
+	io.Reader
+	closer func() error
+}
+
+var _ io.ReadCloser = &readerWithCloser{}
+
+func (d noOpReadCloser) Read(b []byte) (n int, err error) {
+	return 0, io.EOF
+}
+
+func (d noOpReadCloser) Close() error {
+	return nil
+}
+
+func limitReadCloser(rdr io.Reader, closer func() error, sz int64) io.ReadCloser {
+	return &readerWithCloser{
+		Reader: io.LimitReader(rdr, sz),
+		closer: closer,
+	}
+}
+
+func (rwc *readerWithCloser) Close() error {
+	if rwc.closer != nil {
+		return rwc.closer()
+	}
+	return nil
+}
diff --git a/server/s3/list.go b/server/s3/list.go
new file mode 100644
index 0000000000000000000000000000000000000000..bce870ca9ed8c9f6d02889e2e7c5ddb50d764672
--- /dev/null
+++ b/server/s3/list.go
@@ -0,0 +1,53 @@
+// Credits: https://pkg.go.dev/github.com/rclone/rclone@v1.65.2/cmd/serve/s3
+// Package s3 implements a fake s3 server for alist
+package s3
+
+import (
+	"path"
+	"strings"
+
+	"github.com/Mikubill/gofakes3"
+)
+
+func (b *s3Backend) entryListR(bucket, fdPath, name string, addPrefix bool, response *gofakes3.ObjectList) error {
+	fp := path.Join(bucket, fdPath)
+
+	dirEntries, err := getDirEntries(fp)
+	if err != nil {
+		return err
+	}
+
+	for _, entry := range dirEntries {
+		object := entry.GetName()
+
+		// workaround for control-character detection
+		objectPath := path.Join(fdPath, object)
+
+		if !strings.HasPrefix(object, name) {
+			continue
+		}
+
+		if entry.IsDir() {
+			if addPrefix {
+				// response.AddPrefix(gofakes3.URLEncode(objectPath))
+				response.AddPrefix(objectPath)
+				continue
+			}
+			err := b.entryListR(bucket, path.Join(fdPath, object), "", false, response)
+			if err != nil {
+				return err
+			}
+		} else {
+			item := &gofakes3.Content{
+				// Key: gofakes3.URLEncode(objectPath),
+				Key:          objectPath,
+				LastModified: gofakes3.NewContentTime(entry.ModTime()),
+				ETag:         getFileHash(entry),
+				Size:         entry.GetSize(),
+				StorageClass: gofakes3.StorageStandard,
+			}
+			response.Add(item)
+		}
+	}
+	return nil
+}
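// A quick way to exercise this backend end to end is any standard S3 client.
// A hedged sketch with aws-sdk-go v1 follows; the endpoint, credentials and
// bucket name are placeholders, and the key pair corresponds to the
// S3AccessKeyId / S3SecretAccessKey settings read by authlistResolver below:
//
//	package main
//
//	import (
//		"bytes"
//		"fmt"
//		"log"
//
//		"github.com/aws/aws-sdk-go/aws"
//		"github.com/aws/aws-sdk-go/aws/credentials"
//		"github.com/aws/aws-sdk-go/aws/session"
//		"github.com/aws/aws-sdk-go/service/s3"
//	)
//
//	func main() {
//		sess, err := session.NewSession(&aws.Config{
//			Region:           aws.String("us-east-1"),
//			Endpoint:         aws.String("http://127.0.0.1:5246"), // placeholder address
//			Credentials:      credentials.NewStaticCredentials("KEY", "SECRET", ""),
//			S3ForcePathStyle: aws.Bool(true), // bucket in the path, as this server expects
//		})
//		if err != nil {
//			log.Fatal(err)
//		}
//		client := s3.New(sess)
//
//		if _, err = client.PutObject(&s3.PutObjectInput{
//			Bucket: aws.String("data"),
//			Key:    aws.String("hello.txt"),
//			Body:   bytes.NewReader([]byte("hello alist")),
//		}); err != nil {
//			log.Fatal(err)
//		}
//
//		out, err := client.ListBuckets(&s3.ListBucketsInput{})
//		if err != nil {
//			log.Fatal(err)
//		}
//		for _, b := range out.Buckets {
//			fmt.Println(*b.Name)
//		}
//	}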
diff --git a/server/s3/logger.go b/server/s3/logger.go
new file mode 100644
index 0000000000000000000000000000000000000000..7566fa8a116fdbaaab2628fb8e7d05f975bfc1cf
--- /dev/null
+++ b/server/s3/logger.go
@@ -0,0 +1,27 @@
+// Credits: https://pkg.go.dev/github.com/rclone/rclone@v1.65.2/cmd/serve/s3
+// Package s3 implements a fake s3 server for alist
+package s3
+
+import (
+	"fmt"
+
+	"github.com/Mikubill/gofakes3"
+	"github.com/alist-org/alist/v3/pkg/utils"
+)
+
+// logger adapts gofakes3 log output to alist's logger
+type logger struct{}
+
+// Print routes a gofakes3 log message to the matching alist log level
+func (l logger) Print(level gofakes3.LogLevel, v ...interface{}) {
+	switch level {
+	default:
+		fallthrough
+	case gofakes3.LogErr:
+		utils.Log.Errorf("serve s3: %s", fmt.Sprintln(v...))
+	case gofakes3.LogWarn:
+		utils.Log.Infof("serve s3: %s", fmt.Sprintln(v...))
+	case gofakes3.LogInfo:
+		utils.Log.Debugf("serve s3: %s", fmt.Sprintln(v...))
+	}
+}
diff --git a/server/s3/pager.go b/server/s3/pager.go
new file mode 100644
index 0000000000000000000000000000000000000000..3268b0ca23422a206cdf472d3de90e68ef76d628
--- /dev/null
+++ b/server/s3/pager.go
@@ -0,0 +1,67 @@
+// Credits: https://pkg.go.dev/github.com/rclone/rclone@v1.65.2/cmd/serve/s3
+// Package s3 implements a fake s3 server for alist
+package s3
+
+import (
+	"sort"
+
+	"github.com/Mikubill/gofakes3"
+)
+
+// pager splits the object list into multiple pages.
+func (db *s3Backend) pager(list *gofakes3.ObjectList, page gofakes3.ListBucketPage) (*gofakes3.ObjectList, error) {
+	// sort prefixes alphabetically
+	sort.Slice(list.CommonPrefixes, func(i, j int) bool {
+		return list.CommonPrefixes[i].Prefix < list.CommonPrefixes[j].Prefix
+	})
+	// sort contents by modtime
+	sort.Slice(list.Contents, func(i, j int) bool {
+		return list.Contents[i].LastModified.Before(list.Contents[j].LastModified.Time)
+	})
+	tokens := page.MaxKeys
+	if tokens == 0 {
+		tokens = 1000
+	}
+	if page.HasMarker {
+		for i, obj := range list.Contents {
+			if obj.Key == page.Marker {
+				list.Contents = list.Contents[i+1:]
+				break
+			}
+		}
+		for i, obj := range list.CommonPrefixes {
+			if obj.Prefix == page.Marker {
+				list.CommonPrefixes = list.CommonPrefixes[i+1:]
+				break
+			}
+		}
+	}
+
+	response := gofakes3.NewObjectList()
+	for _, obj := range list.CommonPrefixes {
+		if tokens <= 0 {
+			break
+		}
+		response.AddPrefix(obj.Prefix)
+		tokens--
+	}
+
+	for _, obj := range list.Contents {
+		if tokens <= 0 {
+			break
+		}
+		response.Add(obj)
+		tokens--
+	}
+
+	if len(list.CommonPrefixes)+len(list.Contents) > int(page.MaxKeys) {
+		response.IsTruncated = true
+		if len(response.Contents) > 0 {
+			response.NextMarker = response.Contents[len(response.Contents)-1].Key
+		} else {
+			response.NextMarker = response.CommonPrefixes[len(response.CommonPrefixes)-1].Prefix
+		}
+	}
+
+	return response, nil
+}
diff --git a/server/s3/server.go b/server/s3/server.go
new file mode 100644
index 0000000000000000000000000000000000000000..19df735fb5d0217cff22c8e9b8607d3df2ad1335
--- /dev/null
+++ b/server/s3/server.go
@@ -0,0 +1,27 @@
+// Credits: https://pkg.go.dev/github.com/rclone/rclone@v1.65.2/cmd/serve/s3
+// Package s3 implements a fake s3 server for alist
+package s3
+
+import (
+	"context"
+	"math/rand"
+	"net/http"
+
+	"github.com/Mikubill/gofakes3"
+)
+
+// NewServer makes a new S3 server backed by alist's filesystem
+func NewServer(ctx context.Context) (h http.Handler, err error) {
+	var newLogger logger
+	faker := gofakes3.New(
+		newBackend(),
+		// gofakes3.WithHostBucket(!opt.pathBucketMode),
+		gofakes3.WithLogger(newLogger),
+		gofakes3.WithRequestID(rand.Uint64()),
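+		// WithV4Auth below enforces AWS Signature V4 using the key pair from
+		// the S3AccessKeyId / S3SecretAccessKey settings; authlistResolver in
+		// server/s3/utils.go returns nil when no keys are configured
+		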
gofakes3.WithoutVersioning(), + gofakes3.WithV4Auth(authlistResolver()), + gofakes3.WithIntegrityCheck(true), // Check Content-MD5 if supplied + ) + + return faker.Server(), nil +} diff --git a/server/s3/utils.go b/server/s3/utils.go new file mode 100644 index 0000000000000000000000000000000000000000..98c271f76a36af44476d9a0f73eb922b35b59c1b --- /dev/null +++ b/server/s3/utils.go @@ -0,0 +1,160 @@ +// Credits: https://pkg.go.dev/github.com/rclone/rclone@v1.65.2/cmd/serve/s3 +// Package s3 implements a fake s3 server for alist +package s3 + +import ( + "context" + "encoding/json" + "strings" + + "github.com/Mikubill/gofakes3" + "github.com/alist-org/alist/v3/internal/conf" + "github.com/alist-org/alist/v3/internal/errs" + "github.com/alist-org/alist/v3/internal/fs" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/internal/op" + "github.com/alist-org/alist/v3/internal/setting" +) + +type Bucket struct { + Name string `json:"name"` + Path string `json:"path"` +} + +func getAndParseBuckets() ([]Bucket, error) { + var res []Bucket + err := json.Unmarshal([]byte(setting.GetStr(conf.S3Buckets)), &res) + return res, err +} + +func getBucketByName(name string) (Bucket, error) { + buckets, err := getAndParseBuckets() + if err != nil { + return Bucket{}, err + } + for _, b := range buckets { + if b.Name == name { + return b, nil + } + } + return Bucket{}, gofakes3.BucketNotFound(name) +} + +func getDirEntries(path string) ([]model.Obj, error) { + ctx := context.Background() + meta, _ := op.GetNearestMeta(path) + fi, err := fs.Get(context.WithValue(ctx, "meta", meta), path, &fs.GetArgs{}) + if errs.IsNotFoundError(err) { + return nil, gofakes3.ErrNoSuchKey + } else if err != nil { + return nil, gofakes3.ErrNoSuchKey + } + + if !fi.IsDir() { + return nil, gofakes3.ErrNoSuchKey + } + + dirEntries, err := fs.List(context.WithValue(ctx, "meta", meta), path, &fs.ListArgs{}) + if err != nil { + return nil, err + } + + return dirEntries, nil +} + +// func getFileHashByte(node interface{}) []byte { +// b, err := hex.DecodeString(getFileHash(node)) +// if err != nil { +// return nil +// } +// return b +// } + +func getFileHash(node interface{}) string { + // var o fs.Object + + // switch b := node.(type) { + // case vfs.Node: + // fsObj, ok := b.DirEntry().(fs.Object) + // if !ok { + // fs.Debugf("serve s3", "File uploading - reading hash from VFS cache") + // in, err := b.Open(os.O_RDONLY) + // if err != nil { + // return "" + // } + // defer func() { + // _ = in.Close() + // }() + // h, err := hash.NewMultiHasherTypes(hash.NewHashSet(hash.MD5)) + // if err != nil { + // return "" + // } + // _, err = io.Copy(h, in) + // if err != nil { + // return "" + // } + // return h.Sums()[hash.MD5] + // } + // o = fsObj + // case fs.Object: + // o = b + // } + + // hash, err := o.Hash(context.Background(), hash.MD5) + // if err != nil { + // return "" + // } + // return hash + return "" +} + +func prefixParser(p *gofakes3.Prefix) (path, remaining string) { + idx := strings.LastIndexByte(p.Prefix, '/') + if idx < 0 { + return "", p.Prefix + } + return p.Prefix[:idx], p.Prefix[idx+1:] +} + +// // FIXME this could be implemented by VFS.MkdirAll() +// func mkdirRecursive(path string, VFS *vfs.VFS) error { +// path = strings.Trim(path, "/") +// dirs := strings.Split(path, "/") +// dir := "" +// for _, d := range dirs { +// dir += "/" + d +// if _, err := VFS.Stat(dir); err != nil { +// err := VFS.Mkdir(dir, 0777) +// if err != nil { +// return err +// } +// } +// } +// return nil +// } + +// 
func rmdirRecursive(p string, VFS *vfs.VFS) { +// dir := path.Dir(p) +// if !strings.ContainsAny(dir, "/\\") { +// // might be bucket(root) +// return +// } +// if _, err := VFS.Stat(dir); err == nil { +// err := VFS.Remove(dir) +// if err != nil { +// return +// } +// rmdirRecursive(dir, VFS) +// } +// } + +func authlistResolver() map[string]string { + s3accesskeyid := setting.GetStr(conf.S3AccessKeyId) + s3secretaccesskey := setting.GetStr(conf.S3SecretAccessKey) + if s3accesskeyid == "" && s3secretaccesskey == "" { + return nil + } + authList := make(map[string]string) + authList[s3accesskeyid] = s3secretaccesskey + return authList +} diff --git a/server/static/config.go b/server/static/config.go new file mode 100644 index 0000000000000000000000000000000000000000..7a8250a9046d5403007499512c78f19a8f32201d --- /dev/null +++ b/server/static/config.go @@ -0,0 +1,27 @@ +package static + +import ( + "strings" + + "github.com/alist-org/alist/v3/internal/conf" + "github.com/alist-org/alist/v3/pkg/utils" +) + +type SiteConfig struct { + BasePath string + Cdn string +} + +func getSiteConfig() SiteConfig { + siteConfig := SiteConfig{ + BasePath: conf.URL.Path, + Cdn: strings.ReplaceAll(strings.TrimSuffix(conf.Conf.Cdn, "/"), "$version", conf.WebVersion), + } + if siteConfig.BasePath != "" { + siteConfig.BasePath = utils.FixAndCleanPath(siteConfig.BasePath) + } + if siteConfig.Cdn == "" { + siteConfig.Cdn = strings.TrimSuffix(siteConfig.BasePath, "/") + } + return siteConfig +} diff --git a/server/static/static.go b/server/static/static.go new file mode 100644 index 0000000000000000000000000000000000000000..ec16014c22b881863c42534b81acaf267f574874 --- /dev/null +++ b/server/static/static.go @@ -0,0 +1,115 @@ +package static + +import ( + "errors" + "fmt" + "io" + "io/fs" + "net/http" + "os" + "strings" + + "github.com/alist-org/alist/v3/internal/conf" + "github.com/alist-org/alist/v3/internal/setting" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/alist-org/alist/v3/public" + "github.com/gin-gonic/gin" +) + +var static fs.FS + +func initStatic() { + if conf.Conf.DistDir == "" { + dist, err := fs.Sub(public.Public, "dist") + if err != nil { + utils.Log.Fatalf("failed to read dist dir") + } + static = dist + return + } + static = os.DirFS(conf.Conf.DistDir) +} + +func initIndex() { + indexFile, err := static.Open("index.html") + if err != nil { + if errors.Is(err, fs.ErrNotExist) { + utils.Log.Fatalf("index.html not exist, you may forget to put dist of frontend to public/dist") + } + utils.Log.Fatalf("failed to read index.html: %v", err) + } + defer func() { + _ = indexFile.Close() + }() + index, err := io.ReadAll(indexFile) + if err != nil { + utils.Log.Fatalf("failed to read dist/index.html") + } + conf.RawIndexHtml = string(index) + siteConfig := getSiteConfig() + replaceMap := map[string]string{ + "cdn: undefined": fmt.Sprintf("cdn: '%s'", siteConfig.Cdn), + "base_path: undefined": fmt.Sprintf("base_path: '%s'", siteConfig.BasePath), + } + for k, v := range replaceMap { + conf.RawIndexHtml = strings.Replace(conf.RawIndexHtml, k, v, 1) + } + UpdateIndex() +} + +func UpdateIndex() { + favicon := setting.GetStr(conf.Favicon) + title := setting.GetStr(conf.SiteTitle) + customizeHead := setting.GetStr(conf.CustomizeHead) + customizeBody := setting.GetStr(conf.CustomizeBody) + mainColor := setting.GetStr(conf.MainColor) + conf.ManageHtml = conf.RawIndexHtml + replaceMap1 := map[string]string{ + "https://jsd.nn.ci/gh/alist-org/logo@main/logo.svg": favicon, + "Loading...": title, + 
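// each key must match a placeholder string emitted by the alist-web build
+		// of index.html; every entry is replaced exactly once
+		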
"main_color: undefined": fmt.Sprintf("main_color: '%s'", mainColor), + } + for k, v := range replaceMap1 { + conf.ManageHtml = strings.Replace(conf.ManageHtml, k, v, 1) + } + conf.IndexHtml = conf.ManageHtml + replaceMap2 := map[string]string{ + "": customizeHead, + "": customizeBody, + } + for k, v := range replaceMap2 { + conf.IndexHtml = strings.Replace(conf.IndexHtml, k, v, 1) + } +} + +func Static(r *gin.RouterGroup, noRoute func(handlers ...gin.HandlerFunc)) { + initStatic() + initIndex() + folders := []string{"assets", "images", "streamer", "static"} + r.Use(func(c *gin.Context) { + for i := range folders { + if strings.HasPrefix(c.Request.RequestURI, fmt.Sprintf("/%s/", folders[i])) { + c.Header("Cache-Control", "public, max-age=15552000") + } + } + }) + for i, folder := range folders { + sub, err := fs.Sub(static, folder) + if err != nil { + utils.Log.Fatalf("can't find folder: %s", folder) + } + r.StaticFS(fmt.Sprintf("/%s/", folders[i]), http.FS(sub)) + } + + noRoute(func(c *gin.Context) { + c.Header("Content-Type", "text/html") + c.Status(200) + if strings.HasPrefix(c.Request.URL.Path, "/@manage") { + _, _ = c.Writer.WriteString(conf.ManageHtml) + } else { + _, _ = c.Writer.WriteString(conf.IndexHtml) + } + c.Writer.Flush() + c.Writer.WriteHeaderNow() + }) +} diff --git a/server/webdav.go b/server/webdav.go new file mode 100644 index 0000000000000000000000000000000000000000..2b5c9618b861d0c27cc297b872c4287360d2bf74 --- /dev/null +++ b/server/webdav.go @@ -0,0 +1,114 @@ +package server + +import ( + "context" + "crypto/subtle" + "net/http" + "path" + "strings" + + "github.com/alist-org/alist/v3/internal/conf" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/internal/op" + "github.com/alist-org/alist/v3/internal/setting" + "github.com/alist-org/alist/v3/pkg/utils" + "github.com/alist-org/alist/v3/server/webdav" + "github.com/gin-gonic/gin" + log "github.com/sirupsen/logrus" +) + +var handler *webdav.Handler + +func WebDav(dav *gin.RouterGroup) { + handler = &webdav.Handler{ + Prefix: path.Join(conf.URL.Path, "/dav"), + LockSystem: webdav.NewMemLS(), + Logger: func(request *http.Request, err error) { + log.Errorf("%s %s %+v", request.Method, request.URL.Path, err) + }, + } + dav.Use(WebDAVAuth) + dav.Any("/*path", ServeWebDAV) + dav.Any("", ServeWebDAV) + dav.Handle("PROPFIND", "/*path", ServeWebDAV) + dav.Handle("PROPFIND", "", ServeWebDAV) + dav.Handle("MKCOL", "/*path", ServeWebDAV) + dav.Handle("LOCK", "/*path", ServeWebDAV) + dav.Handle("UNLOCK", "/*path", ServeWebDAV) + dav.Handle("PROPPATCH", "/*path", ServeWebDAV) + dav.Handle("COPY", "/*path", ServeWebDAV) + dav.Handle("MOVE", "/*path", ServeWebDAV) +} + +func ServeWebDAV(c *gin.Context) { + user := c.MustGet("user").(*model.User) + ctx := context.WithValue(c.Request.Context(), "user", user) + handler.ServeHTTP(c.Writer, c.Request.WithContext(ctx)) +} + +func WebDAVAuth(c *gin.Context) { + guest, _ := op.GetGuest() + username, password, ok := c.Request.BasicAuth() + if !ok { + bt := c.GetHeader("Authorization") + log.Debugf("[webdav auth] token: %s", bt) + if strings.HasPrefix(bt, "Bearer") { + bt = strings.TrimPrefix(bt, "Bearer ") + token := setting.GetStr(conf.Token) + if token != "" && subtle.ConstantTimeCompare([]byte(bt), []byte(token)) == 1 { + admin, err := op.GetAdmin() + if err != nil { + log.Errorf("[webdav auth] failed get admin user: %+v", err) + c.Status(http.StatusInternalServerError) + c.Abort() + return + } + c.Set("user", admin) + c.Next() + return + } + } + if 
c.Request.Method == "OPTIONS" { + c.Set("user", guest) + c.Next() + return + } + c.Writer.Header()["WWW-Authenticate"] = []string{`Basic realm="alist"`} + c.Status(http.StatusUnauthorized) + c.Abort() + return + } + user, err := op.GetUserByName(username) + if err != nil || user.ValidateRawPassword(password) != nil { + if c.Request.Method == "OPTIONS" { + c.Set("user", guest) + c.Next() + return + } + c.Status(http.StatusUnauthorized) + c.Abort() + return + } + if user.Disabled || !user.CanWebdavRead() { + if c.Request.Method == "OPTIONS" { + c.Set("user", guest) + c.Next() + return + } + c.Status(http.StatusForbidden) + c.Abort() + return + } + if !user.CanWebdavManage() && utils.SliceContains([]string{"PUT", "DELETE", "PROPPATCH", "MKCOL", "COPY", "MOVE"}, c.Request.Method) { + if c.Request.Method == "OPTIONS" { + c.Set("user", guest) + c.Next() + return + } + c.Status(http.StatusForbidden) + c.Abort() + return + } + c.Set("user", user) + c.Next() +} diff --git a/server/webdav/buffered_response_writer.go b/server/webdav/buffered_response_writer.go new file mode 100644 index 0000000000000000000000000000000000000000..ed653eaec393cf2a8ab7afdf4ff35d2cea2a24bb --- /dev/null +++ b/server/webdav/buffered_response_writer.go @@ -0,0 +1,46 @@ +package webdav + +import ( + "net/http" +) + +type bufferedResponseWriter struct { + statusCode int + data []byte + header http.Header +} + +func (w *bufferedResponseWriter) Header() http.Header { + if w.header == nil { + w.header = make(http.Header) + } + return w.header +} + +func (w *bufferedResponseWriter) Write(bytes []byte) (int, error) { + w.data = append(w.data, bytes...) + return len(bytes), nil +} + +func (w *bufferedResponseWriter) WriteHeader(statusCode int) { + if w.statusCode == 0 { + w.statusCode = statusCode + } +} + +func (w *bufferedResponseWriter) WriteToResponse(rw http.ResponseWriter) (int, error) { + h := rw.Header() + for k, vs := range w.header { + for _, v := range vs { + h.Add(k, v) + } + } + rw.WriteHeader(w.statusCode) + return rw.Write(w.data) +} + +func newBufferedResponseWriter() *bufferedResponseWriter { + return &bufferedResponseWriter{ + statusCode: 0, + } +} diff --git a/server/webdav/file.go b/server/webdav/file.go new file mode 100644 index 0000000000000000000000000000000000000000..dcb096296387040e55db7b54a1a496ef40a5aeee --- /dev/null +++ b/server/webdav/file.go @@ -0,0 +1,116 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package webdav + +import ( + "context" + "net/http" + "path" + "path/filepath" + + "github.com/alist-org/alist/v3/internal/conf" + "github.com/alist-org/alist/v3/internal/fs" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/internal/op" +) + +// slashClean is equivalent to but slightly more efficient than +// path.Clean("/" + name). +func slashClean(name string) string { + if name == "" || name[0] != '/' { + name = "/" + name + } + return path.Clean(name) +} + +// moveFiles moves files and/or directories from src to dst. +// +// See section 9.9.4 for when various HTTP status codes apply. 
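+//
+// For instance, MOVE /dav/a/x.txt with header "Destination: /dav/b/y.txt"
+// (paths illustrative) becomes fs.Move of /a/x.txt into /b followed by
+// fs.Rename of /b/x.txt to y.txt, since source and destination dirs differ.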
+func moveFiles(ctx context.Context, src, dst string, overwrite bool) (status int, err error) { + srcDir := path.Dir(src) + dstDir := path.Dir(dst) + srcName := path.Base(src) + dstName := path.Base(dst) + if srcDir == dstDir { + err = fs.Rename(ctx, src, dstName) + } else { + err = fs.Move(ctx, src, dstDir) + if err != nil { + return http.StatusInternalServerError, err + } + if srcName != dstName { + err = fs.Rename(ctx, path.Join(dstDir, srcName), dstName) + } + } + if err != nil { + return http.StatusInternalServerError, err + } + // TODO if there are no files copy, should return 204 + return http.StatusCreated, nil +} + +// copyFiles copies files and/or directories from src to dst. +// +// See section 9.8.5 for when various HTTP status codes apply. +func copyFiles(ctx context.Context, src, dst string, overwrite bool) (status int, err error) { + dstDir := path.Dir(dst) + _, err = fs.Copy(context.WithValue(ctx, conf.NoTaskKey, struct{}{}), src, dstDir, overwrite) + if err != nil { + return http.StatusInternalServerError, err + } + // TODO if there are no files copy, should return 204 + return http.StatusCreated, nil +} + +// walkFS traverses filesystem fs starting at name up to depth levels. +// +// Allowed values for depth are 0, 1 or infiniteDepth. For each visited node, +// walkFS calls walkFn. If a visited file system node is a directory and +// walkFn returns path.SkipDir, walkFS will skip traversal of this node. +func walkFS(ctx context.Context, depth int, name string, info model.Obj, walkFn func(reqPath string, info model.Obj, err error) error) error { + // This implementation is based on Walk's code in the standard path/path package. + err := walkFn(name, info, nil) + if err != nil { + if info.IsDir() && err == filepath.SkipDir { + return nil + } + return err + } + if !info.IsDir() || depth == 0 { + return nil + } + if depth == 1 { + depth = 0 + } + meta, _ := op.GetNearestMeta(name) + // Read directory names. + objs, err := fs.List(context.WithValue(ctx, "meta", meta), name, &fs.ListArgs{}) + //f, err := fs.OpenFile(ctx, name, os.O_RDONLY, 0) + //if err != nil { + // return walkFn(name, info, err) + //} + //fileInfos, err := f.Readdir(0) + //f.Close() + if err != nil { + return walkFn(name, info, err) + } + + for _, fileInfo := range objs { + filename := path.Join(name, fileInfo.GetName()) + if err != nil { + if err := walkFn(filename, fileInfo, err); err != nil && err != filepath.SkipDir { + return err + } + } else { + err = walkFS(ctx, depth, filename, fileInfo, walkFn) + if err != nil { + if !fileInfo.IsDir() || err != filepath.SkipDir { + return err + } + } + } + } + return nil +} diff --git a/server/webdav/if.go b/server/webdav/if.go new file mode 100644 index 0000000000000000000000000000000000000000..416e81cdfddf0fbb8ee5c8231a977fd6aa9c91a0 --- /dev/null +++ b/server/webdav/if.go @@ -0,0 +1,173 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package webdav + +// The If header is covered by Section 10.4. +// http://www.webdav.org/specs/rfc4918.html#HEADER_If + +import ( + "strings" +) + +// ifHeader is a disjunction (OR) of ifLists. +type ifHeader struct { + lists []ifList +} + +// ifList is a conjunction (AND) of Conditions, and an optional resource tag. +type ifList struct { + resourceTag string + conditions []Condition +} + +// parseIfHeader parses the "If: foo bar" HTTP header. 
The httpHeader string +// should omit the "If:" prefix and have any "\r\n"s collapsed to a " ", as is +// returned by req.Header.Get("If") for a http.Request req. +func parseIfHeader(httpHeader string) (h ifHeader, ok bool) { + s := strings.TrimSpace(httpHeader) + switch tokenType, _, _ := lex(s); tokenType { + case '(': + return parseNoTagLists(s) + case angleTokenType: + return parseTaggedLists(s) + default: + return ifHeader{}, false + } +} + +func parseNoTagLists(s string) (h ifHeader, ok bool) { + for { + l, remaining, ok := parseList(s) + if !ok { + return ifHeader{}, false + } + h.lists = append(h.lists, l) + if remaining == "" { + return h, true + } + s = remaining + } +} + +func parseTaggedLists(s string) (h ifHeader, ok bool) { + resourceTag, n := "", 0 + for first := true; ; first = false { + tokenType, tokenStr, remaining := lex(s) + switch tokenType { + case angleTokenType: + if !first && n == 0 { + return ifHeader{}, false + } + resourceTag, n = tokenStr, 0 + s = remaining + case '(': + n++ + l, remaining, ok := parseList(s) + if !ok { + return ifHeader{}, false + } + l.resourceTag = resourceTag + h.lists = append(h.lists, l) + if remaining == "" { + return h, true + } + s = remaining + default: + return ifHeader{}, false + } + } +} + +func parseList(s string) (l ifList, remaining string, ok bool) { + tokenType, _, s := lex(s) + if tokenType != '(' { + return ifList{}, "", false + } + for { + tokenType, _, remaining = lex(s) + if tokenType == ')' { + if len(l.conditions) == 0 { + return ifList{}, "", false + } + return l, remaining, true + } + c, remaining, ok := parseCondition(s) + if !ok { + return ifList{}, "", false + } + l.conditions = append(l.conditions, c) + s = remaining + } +} + +func parseCondition(s string) (c Condition, remaining string, ok bool) { + tokenType, tokenStr, s := lex(s) + if tokenType == notTokenType { + c.Not = true + tokenType, tokenStr, s = lex(s) + } + switch tokenType { + case strTokenType, angleTokenType: + c.Token = tokenStr + case squareTokenType: + c.ETag = tokenStr + default: + return Condition{}, "", false + } + return c, s, true +} + +// Single-rune tokens like '(' or ')' have a token type equal to their rune. +// All other tokens have a negative token type. +const ( + errTokenType = rune(-1) + eofTokenType = rune(-2) + strTokenType = rune(-3) + notTokenType = rune(-4) + angleTokenType = rune(-5) + squareTokenType = rune(-6) +) + +func lex(s string) (tokenType rune, tokenStr string, remaining string) { + // The net/textproto Reader that parses the HTTP header will collapse + // Linear White Space that spans multiple "\r\n" lines to a single " ", + // so we don't need to look for '\r' or '\n'. 
+ for len(s) > 0 && (s[0] == '\t' || s[0] == ' ') { + s = s[1:] + } + if len(s) == 0 { + return eofTokenType, "", "" + } + i := 0 +loop: + for ; i < len(s); i++ { + switch s[i] { + case '\t', ' ', '(', ')', '<', '>', '[', ']': + break loop + } + } + + if i != 0 { + tokenStr, remaining = s[:i], s[i:] + if tokenStr == "Not" { + return notTokenType, "", remaining + } + return strTokenType, tokenStr, remaining + } + + j := 0 + switch s[0] { + case '<': + j, tokenType = strings.IndexByte(s, '>'), angleTokenType + case '[': + j, tokenType = strings.IndexByte(s, ']'), squareTokenType + default: + return rune(s[0]), "", s[1:] + } + if j < 0 { + return errTokenType, "", "" + } + return tokenType, s[1:j], s[j+1:] +} diff --git a/server/webdav/internal/xml/README b/server/webdav/internal/xml/README new file mode 100644 index 0000000000000000000000000000000000000000..89656f4896254884a8166dca32a3fa16363c25c3 --- /dev/null +++ b/server/webdav/internal/xml/README @@ -0,0 +1,11 @@ +This is a fork of the encoding/xml package at ca1d6c4, the last commit before +https://go.googlesource.com/go/+/c0d6d33 "encoding/xml: restore Go 1.4 name +space behavior" made late in the lead-up to the Go 1.5 release. + +The list of encoding/xml changes is at +https://go.googlesource.com/go/+log/master/src/encoding/xml + +This fork is temporary, and I (nigeltao) expect to revert it after Go 1.6 is +released. + +See http://golang.org/issue/11841 diff --git a/server/webdav/internal/xml/atom_test.go b/server/webdav/internal/xml/atom_test.go new file mode 100644 index 0000000000000000000000000000000000000000..a71284312af732cfaaba3233c889d1311027b483 --- /dev/null +++ b/server/webdav/internal/xml/atom_test.go @@ -0,0 +1,56 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xml + +import "time" + +var atomValue = &Feed{ + XMLName: Name{"http://www.w3.org/2005/Atom", "feed"}, + Title: "Example Feed", + Link: []Link{{Href: "http://example.org/"}}, + Updated: ParseTime("2003-12-13T18:30:02Z"), + Author: Person{Name: "John Doe"}, + Id: "urn:uuid:60a76c80-d399-11d9-b93C-0003939e0af6", + + Entry: []Entry{ + { + Title: "Atom-Powered Robots Run Amok", + Link: []Link{{Href: "http://example.org/2003/12/13/atom03"}}, + Id: "urn:uuid:1225c695-cfb8-4ebb-aaaa-80da344efa6a", + Updated: ParseTime("2003-12-13T18:30:02Z"), + Summary: NewText("Some text."), + }, + }, +} + +var atomXml = `` + + `` + + `Example Feed` + + `urn:uuid:60a76c80-d399-11d9-b93C-0003939e0af6` + + `` + + `John Doe` + + `` + + `Atom-Powered Robots Run Amok` + + `urn:uuid:1225c695-cfb8-4ebb-aaaa-80da344efa6a` + + `` + + `2003-12-13T18:30:02Z` + + `` + + `Some text.` + + `` + + `` + +func ParseTime(str string) time.Time { + t, err := time.Parse(time.RFC3339, str) + if err != nil { + panic(err) + } + return t +} + +func NewText(text string) Text { + return Text{ + Body: text, + } +} diff --git a/server/webdav/internal/xml/example_test.go b/server/webdav/internal/xml/example_test.go new file mode 100644 index 0000000000000000000000000000000000000000..21b48dea534ef46cf55161ff54aff23e58dc61f6 --- /dev/null +++ b/server/webdav/internal/xml/example_test.go @@ -0,0 +1,151 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +package xml_test + +import ( + "encoding/xml" + "fmt" + "os" +) + +func ExampleMarshalIndent() { + type Address struct { + City, State string + } + type Person struct { + XMLName xml.Name `xml:"person"` + Id int `xml:"id,attr"` + FirstName string `xml:"name>first"` + LastName string `xml:"name>last"` + Age int `xml:"age"` + Height float32 `xml:"height,omitempty"` + Married bool + Address + Comment string `xml:",comment"` + } + + v := &Person{Id: 13, FirstName: "John", LastName: "Doe", Age: 42} + v.Comment = " Need more details. " + v.Address = Address{"Hanga Roa", "Easter Island"} + + output, err := xml.MarshalIndent(v, " ", " ") + if err != nil { + fmt.Printf("error: %v\n", err) + } + + os.Stdout.Write(output) + // Output: + // + // + // John + // Doe + // + // 42 + // false + // Hanga Roa + // Easter Island + // + // +} + +func ExampleEncoder() { + type Address struct { + City, State string + } + type Person struct { + XMLName xml.Name `xml:"person"` + Id int `xml:"id,attr"` + FirstName string `xml:"name>first"` + LastName string `xml:"name>last"` + Age int `xml:"age"` + Height float32 `xml:"height,omitempty"` + Married bool + Address + Comment string `xml:",comment"` + } + + v := &Person{Id: 13, FirstName: "John", LastName: "Doe", Age: 42} + v.Comment = " Need more details. " + v.Address = Address{"Hanga Roa", "Easter Island"} + + enc := xml.NewEncoder(os.Stdout) + enc.Indent(" ", " ") + if err := enc.Encode(v); err != nil { + fmt.Printf("error: %v\n", err) + } + + // Output: + // + // + // John + // Doe + // + // 42 + // false + // Hanga Roa + // Easter Island + // + // +} + +// This example demonstrates unmarshaling an XML excerpt into a value with +// some preset fields. Note that the Phone field isn't modified and that +// the XML element is ignored. Also, the Groups field is assigned +// considering the element path provided in its tag. +func ExampleUnmarshal() { + type Email struct { + Where string `xml:"where,attr"` + Addr string + } + type Address struct { + City, State string + } + type Result struct { + XMLName xml.Name `xml:"Person"` + Name string `xml:"FullName"` + Phone string + Email []Email + Groups []string `xml:"Group>Value"` + Address + } + v := Result{Name: "none", Phone: "none"} + + data := ` + + Grace R. Emlin + Example Inc. + + gre@example.com + + + gre@work.com + + + Friends + Squash + + Hanga Roa + Easter Island + + ` + err := xml.Unmarshal([]byte(data), &v) + if err != nil { + fmt.Printf("error: %v", err) + return + } + fmt.Printf("XMLName: %#v\n", v.XMLName) + fmt.Printf("Name: %q\n", v.Name) + fmt.Printf("Phone: %q\n", v.Phone) + fmt.Printf("Email: %v\n", v.Email) + fmt.Printf("Groups: %v\n", v.Groups) + fmt.Printf("Address: %v\n", v.Address) + // Output: + // XMLName: xml.Name{Space:"", Local:"Person"} + // Name: "Grace R. Emlin" + // Phone: "none" + // Email: [{home gre@example.com} {work gre@work.com}] + // Groups: [Friends Squash] + // Address: {Hanga Roa Easter Island} +} diff --git a/server/webdav/internal/xml/marshal.go b/server/webdav/internal/xml/marshal.go new file mode 100644 index 0000000000000000000000000000000000000000..4dd0f417fd1153b523f722904873e3eee8edb957 --- /dev/null +++ b/server/webdav/internal/xml/marshal.go @@ -0,0 +1,1223 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +package xml + +import ( + "bufio" + "bytes" + "encoding" + "fmt" + "io" + "reflect" + "strconv" + "strings" +) + +const ( + // A generic XML header suitable for use with the output of Marshal. + // This is not automatically added to any output of this package, + // it is provided as a convenience. + Header = `` + "\n" +) + +// Marshal returns the XML encoding of v. +// +// Marshal handles an array or slice by marshalling each of the elements. +// Marshal handles a pointer by marshalling the value it points at or, if the +// pointer is nil, by writing nothing. Marshal handles an interface value by +// marshalling the value it contains or, if the interface value is nil, by +// writing nothing. Marshal handles all other data by writing one or more XML +// elements containing the data. +// +// The name for the XML elements is taken from, in order of preference: +// - the tag on the XMLName field, if the data is a struct +// - the value of the XMLName field of type xml.Name +// - the tag of the struct field used to obtain the data +// - the name of the struct field used to obtain the data +// - the name of the marshalled type +// +// The XML element for a struct contains marshalled elements for each of the +// exported fields of the struct, with these exceptions: +// - the XMLName field, described above, is omitted. +// - a field with tag "-" is omitted. +// - a field with tag "name,attr" becomes an attribute with +// the given name in the XML element. +// - a field with tag ",attr" becomes an attribute with the +// field name in the XML element. +// - a field with tag ",chardata" is written as character data, +// not as an XML element. +// - a field with tag ",innerxml" is written verbatim, not subject +// to the usual marshalling procedure. +// - a field with tag ",comment" is written as an XML comment, not +// subject to the usual marshalling procedure. It must not contain +// the "--" string within it. +// - a field with a tag including the "omitempty" option is omitted +// if the field value is empty. The empty values are false, 0, any +// nil pointer or interface value, and any array, slice, map, or +// string of length zero. +// - an anonymous struct field is handled as if the fields of its +// value were part of the outer struct. +// +// If a field uses a tag "a>b>c", then the element c will be nested inside +// parent elements a and b. Fields that appear next to each other that name +// the same parent will be enclosed in one XML element. +// +// See MarshalIndent for an example. +// +// Marshal will return an error if asked to marshal a channel, function, or map. +func Marshal(v interface{}) ([]byte, error) { + var b bytes.Buffer + if err := NewEncoder(&b).Encode(v); err != nil { + return nil, err + } + return b.Bytes(), nil +} + +// Marshaler is the interface implemented by objects that can marshal +// themselves into valid XML elements. +// +// MarshalXML encodes the receiver as zero or more XML elements. +// By convention, arrays or slices are typically encoded as a sequence +// of elements, one per entry. +// Using start as the element tag is not required, but doing so +// will enable Unmarshal to match the XML elements to the correct +// struct field. +// One common implementation strategy is to construct a separate +// value with a layout corresponding to the desired XML and then +// to encode it using e.EncodeElement. +// Another common strategy is to use repeated calls to e.EncodeToken +// to generate the XML output one token at a time. 
+// The sequence of encoded tokens must make up zero or more valid +// XML elements. +type Marshaler interface { + MarshalXML(e *Encoder, start StartElement) error +} + +// MarshalerAttr is the interface implemented by objects that can marshal +// themselves into valid XML attributes. +// +// MarshalXMLAttr returns an XML attribute with the encoded value of the receiver. +// Using name as the attribute name is not required, but doing so +// will enable Unmarshal to match the attribute to the correct +// struct field. +// If MarshalXMLAttr returns the zero attribute Attr{}, no attribute +// will be generated in the output. +// MarshalXMLAttr is used only for struct fields with the +// "attr" option in the field tag. +type MarshalerAttr interface { + MarshalXMLAttr(name Name) (Attr, error) +} + +// MarshalIndent works like Marshal, but each XML element begins on a new +// indented line that starts with prefix and is followed by one or more +// copies of indent according to the nesting depth. +func MarshalIndent(v interface{}, prefix, indent string) ([]byte, error) { + var b bytes.Buffer + enc := NewEncoder(&b) + enc.Indent(prefix, indent) + if err := enc.Encode(v); err != nil { + return nil, err + } + return b.Bytes(), nil +} + +// An Encoder writes XML data to an output stream. +type Encoder struct { + p printer +} + +// NewEncoder returns a new encoder that writes to w. +func NewEncoder(w io.Writer) *Encoder { + e := &Encoder{printer{Writer: bufio.NewWriter(w)}} + e.p.encoder = e + return e +} + +// Indent sets the encoder to generate XML in which each element +// begins on a new indented line that starts with prefix and is followed by +// one or more copies of indent according to the nesting depth. +func (enc *Encoder) Indent(prefix, indent string) { + enc.p.prefix = prefix + enc.p.indent = indent +} + +// Encode writes the XML encoding of v to the stream. +// +// See the documentation for Marshal for details about the conversion +// of Go values to XML. +// +// Encode calls Flush before returning. +func (enc *Encoder) Encode(v interface{}) error { + err := enc.p.marshalValue(reflect.ValueOf(v), nil, nil) + if err != nil { + return err + } + return enc.p.Flush() +} + +// EncodeElement writes the XML encoding of v to the stream, +// using start as the outermost tag in the encoding. +// +// See the documentation for Marshal for details about the conversion +// of Go values to XML. +// +// EncodeElement calls Flush before returning. +func (enc *Encoder) EncodeElement(v interface{}, start StartElement) error { + err := enc.p.marshalValue(reflect.ValueOf(v), nil, &start) + if err != nil { + return err + } + return enc.p.Flush() +} + +var ( + begComment = []byte("") + endProcInst = []byte("?>") + endDirective = []byte(">") +) + +// EncodeToken writes the given XML token to the stream. +// It returns an error if StartElement and EndElement tokens are not +// properly matched. +// +// EncodeToken does not call Flush, because usually it is part of a +// larger operation such as Encode or EncodeElement (or a custom +// Marshaler's MarshalXML invoked during those), and those will call +// Flush when finished. Callers that create an Encoder and then invoke +// EncodeToken directly, without using Encode or EncodeElement, need to +// call Flush when finished to ensure that the XML is written to the +// underlying writer. +// +// EncodeToken allows writing a ProcInst with Target set to "xml" only +// as the first token in the stream. 
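+//
+// A minimal token-level sketch (editor's addition, using this package's API):
+//
+//	var b bytes.Buffer
+//	enc := NewEncoder(&b)
+//	start := StartElement{Name: Name{Local: "p"}}
+//	_ = enc.EncodeToken(start)
+//	_ = enc.EncodeToken(CharData("hello"))
+//	_ = enc.EncodeToken(EndElement{Name: start.Name})
+//	_ = enc.Flush() // b.String() == "<p>hello</p>"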
+var (
+	begComment   = []byte("<!--")
+	endComment   = []byte("-->")
+	endProcInst  = []byte("?>")
+	endDirective = []byte(">")
+)
+
+// EncodeToken writes the given XML token to the stream.
+// It returns an error if StartElement and EndElement tokens are not
+// properly matched.
+//
+// EncodeToken does not call Flush, because usually it is part of a
+// larger operation such as Encode or EncodeElement (or a custom
+// Marshaler's MarshalXML invoked during those), and those will call
+// Flush when finished. Callers that create an Encoder and then invoke
+// EncodeToken directly, without using Encode or EncodeElement, need to
+// call Flush when finished to ensure that the XML is written to the
+// underlying writer.
+//
+// EncodeToken allows writing a ProcInst with Target set to "xml" only
+// as the first token in the stream.
+//
+// When encoding a StartElement holding an XML namespace prefix
+// declaration for a prefix that is not already declared, contained
+// elements (including the StartElement itself) will use the declared
+// prefix when encoding names with matching namespace URIs.
+func (enc *Encoder) EncodeToken(t Token) error {
+	p := &enc.p
+	switch t := t.(type) {
+	case StartElement:
+		if err := p.writeStart(&t); err != nil {
+			return err
+		}
+	case EndElement:
+		if err := p.writeEnd(t.Name); err != nil {
+			return err
+		}
+	case CharData:
+		escapeText(p, t, false)
+	case Comment:
+		if bytes.Contains(t, endComment) {
+			return fmt.Errorf("xml: EncodeToken of Comment containing --> marker")
+		}
+		p.WriteString("<!--")
+		p.Write(t)
+		p.WriteString("-->")
+		return p.cachedWriteError()
+	case ProcInst:
+		// First token to be encoded which is also a ProcInst with target of xml
+		// is the xml declaration. The only ProcInst where target of xml is allowed.
+		if t.Target == "xml" && p.Buffered() != 0 {
+			return fmt.Errorf("xml: EncodeToken of ProcInst xml target only valid for xml declaration, first token encoded")
+		}
+		if !isNameString(t.Target) {
+			return fmt.Errorf("xml: EncodeToken of ProcInst with invalid Target")
+		}
+		if bytes.Contains(t.Inst, endProcInst) {
+			return fmt.Errorf("xml: EncodeToken of ProcInst containing ?> marker")
+		}
+		p.WriteString("<?")
+		p.WriteString(t.Target)
+		if len(t.Inst) > 0 {
+			p.WriteByte(' ')
+			p.Write(t.Inst)
+		}
+		p.WriteString("?>")
+	case Directive:
+		if !isValidDirective(t) {
+			return fmt.Errorf("xml: EncodeToken of Directive containing wrong < or > markers")
+		}
+		p.WriteString("<!")
+		p.Write(t)
+		p.WriteString(">")
+	default:
+		return fmt.Errorf("xml: EncodeToken of invalid token type")
+	}
+	return p.cachedWriteError()
+}
+
+// isValidDirective reports whether dir is a valid directive text,
+// meaning angle brackets are matched, ignoring comments and strings.
+func isValidDirective(dir Directive) bool {
+	var (
+		depth     int
+		inquote   uint8
+		incomment bool
+	)
+	for i, c := range dir {
+		switch {
+		case incomment:
+			if c == '>' {
+				if n := 1 + i - len(endComment); n >= 0 && bytes.Equal(dir[n:i+1], endComment) {
+					incomment = false
+				}
+			}
+			// Just ignore anything in comment
+		case inquote != 0:
+			if c == inquote {
+				inquote = 0
+			}
+			// Just ignore anything within quotes
+		case c == '\'' || c == '"':
+			inquote = c
+		case c == '<':
+			if i+len(begComment) < len(dir) && bytes.Equal(dir[i:i+len(begComment)], begComment) {
+				incomment = true
+			} else {
+				depth++
+			}
+		case c == '>':
+			if depth == 0 {
+				return false
+			}
+			depth--
+		}
+	}
+	return depth == 0 && inquote == 0 && !incomment
+}
+
+// Flush flushes any buffered XML to the underlying writer.
+// See the EncodeToken documentation for details about when it is necessary.
+func (enc *Encoder) Flush() error {
+	return enc.p.Flush()
+}
+
+type printer struct {
+	*bufio.Writer
+	encoder    *Encoder
+	seq        int
+	indent     string
+	prefix     string
+	depth      int
+	indentedIn bool
+	putNewline bool
+	defaultNS  string
+	attrNS     map[string]string // map prefix -> name space
+	attrPrefix map[string]string // map name space -> prefix
+	prefixes   []printerPrefix
+	tags       []Name
+}
+
+// printerPrefix holds a namespace undo record.
+// When an element is popped, the prefix record
+// is set back to the recorded URL. The empty
+// prefix records the URL for the default name space.
+//
+// The start of an element is recorded with an element
+// that has mark=true.
+type printerPrefix struct { + prefix string + url string + mark bool +} + +func (p *printer) prefixForNS(url string, isAttr bool) string { + // The "http://www.w3.org/XML/1998/namespace" name space is predefined as "xml" + // and must be referred to that way. + // (The "http://www.w3.org/2000/xmlns/" name space is also predefined as "xmlns", + // but users should not be trying to use that one directly - that's our job.) + if url == xmlURL { + return "xml" + } + if !isAttr && url == p.defaultNS { + // We can use the default name space. + return "" + } + return p.attrPrefix[url] +} + +// defineNS pushes any namespace definition found in the given attribute. +// If ignoreNonEmptyDefault is true, an xmlns="nonempty" +// attribute will be ignored. +func (p *printer) defineNS(attr Attr, ignoreNonEmptyDefault bool) error { + var prefix string + if attr.Name.Local == "xmlns" { + if attr.Name.Space != "" && attr.Name.Space != "xml" && attr.Name.Space != xmlURL { + return fmt.Errorf("xml: cannot redefine xmlns attribute prefix") + } + } else if attr.Name.Space == "xmlns" && attr.Name.Local != "" { + prefix = attr.Name.Local + if attr.Value == "" { + // Technically, an empty XML namespace is allowed for an attribute. + // From http://www.w3.org/TR/xml-names11/#scoping-defaulting: + // + // The attribute value in a namespace declaration for a prefix may be + // empty. This has the effect, within the scope of the declaration, of removing + // any association of the prefix with a namespace name. + // + // However our namespace prefixes here are used only as hints. There's + // no need to respect the removal of a namespace prefix, so we ignore it. + return nil + } + } else { + // Ignore: it's not a namespace definition + return nil + } + if prefix == "" { + if attr.Value == p.defaultNS { + // No need for redefinition. + return nil + } + if attr.Value != "" && ignoreNonEmptyDefault { + // We have an xmlns="..." value but + // it can't define a name space in this context, + // probably because the element has an empty + // name space. In this case, we just ignore + // the name space declaration. + return nil + } + } else if _, ok := p.attrPrefix[attr.Value]; ok { + // There's already a prefix for the given name space, + // so use that. This prevents us from + // having two prefixes for the same name space + // so attrNS and attrPrefix can remain bijective. + return nil + } + p.pushPrefix(prefix, attr.Value) + return nil +} + +// createNSPrefix creates a name space prefix attribute +// to use for the given name space, defining a new prefix +// if necessary. +// If isAttr is true, the prefix is to be created for an attribute +// prefix, which means that the default name space cannot +// be used. +func (p *printer) createNSPrefix(url string, isAttr bool) { + if _, ok := p.attrPrefix[url]; ok { + // We already have a prefix for the given URL. + return + } + switch { + case !isAttr && url == p.defaultNS: + // We can use the default name space. + return + case url == "": + // The only way we can encode names in the empty + // name space is by using the default name space, + // so we must use that. + if p.defaultNS != "" { + // The default namespace is non-empty, so we + // need to set it to empty. + p.pushPrefix("", "") + } + return + case url == xmlURL: + return + } + // TODO If the URL is an existing prefix, we could + // use it as is. That would enable the + // marshaling of elements that had been unmarshaled + // and with a name space prefix that was not found. 
+ // although technically it would be incorrect. + + // Pick a name. We try to use the final element of the path + // but fall back to _. + prefix := strings.TrimRight(url, "/") + if i := strings.LastIndex(prefix, "/"); i >= 0 { + prefix = prefix[i+1:] + } + if prefix == "" || !isName([]byte(prefix)) || strings.Contains(prefix, ":") { + prefix = "_" + } + if strings.HasPrefix(prefix, "xml") { + // xmlanything is reserved. + prefix = "_" + prefix + } + if p.attrNS[prefix] != "" { + // Name is taken. Find a better one. + for p.seq++; ; p.seq++ { + if id := prefix + "_" + strconv.Itoa(p.seq); p.attrNS[id] == "" { + prefix = id + break + } + } + } + + p.pushPrefix(prefix, url) +} + +// writeNamespaces writes xmlns attributes for all the +// namespace prefixes that have been defined in +// the current element. +func (p *printer) writeNamespaces() { + for i := len(p.prefixes) - 1; i >= 0; i-- { + prefix := p.prefixes[i] + if prefix.mark { + return + } + p.WriteString(" ") + if prefix.prefix == "" { + // Default name space. + p.WriteString(`xmlns="`) + } else { + p.WriteString("xmlns:") + p.WriteString(prefix.prefix) + p.WriteString(`="`) + } + EscapeText(p, []byte(p.nsForPrefix(prefix.prefix))) + p.WriteString(`"`) + } +} + +// pushPrefix pushes a new prefix on the prefix stack +// without checking to see if it is already defined. +func (p *printer) pushPrefix(prefix, url string) { + p.prefixes = append(p.prefixes, printerPrefix{ + prefix: prefix, + url: p.nsForPrefix(prefix), + }) + p.setAttrPrefix(prefix, url) +} + +// nsForPrefix returns the name space for the given +// prefix. Note that this is not valid for the +// empty attribute prefix, which always has an empty +// name space. +func (p *printer) nsForPrefix(prefix string) string { + if prefix == "" { + return p.defaultNS + } + return p.attrNS[prefix] +} + +// markPrefix marks the start of an element on the prefix +// stack. +func (p *printer) markPrefix() { + p.prefixes = append(p.prefixes, printerPrefix{ + mark: true, + }) +} + +// popPrefix pops all defined prefixes for the current +// element. +func (p *printer) popPrefix() { + for len(p.prefixes) > 0 { + prefix := p.prefixes[len(p.prefixes)-1] + p.prefixes = p.prefixes[:len(p.prefixes)-1] + if prefix.mark { + break + } + p.setAttrPrefix(prefix.prefix, prefix.url) + } +} + +// setAttrPrefix sets an attribute name space prefix. +// If url is empty, the attribute is removed. +// If prefix is empty, the default name space is set. +func (p *printer) setAttrPrefix(prefix, url string) { + if prefix == "" { + p.defaultNS = url + return + } + if url == "" { + delete(p.attrPrefix, p.attrNS[prefix]) + delete(p.attrNS, prefix) + return + } + if p.attrPrefix == nil { + // Need to define a new name space. + p.attrPrefix = make(map[string]string) + p.attrNS = make(map[string]string) + } + // Remove any old prefix value. This is OK because we maintain a + // strict one-to-one mapping between prefix and URL (see + // defineNS) + delete(p.attrPrefix, p.attrNS[prefix]) + p.attrPrefix[url] = prefix + p.attrNS[prefix] = url +} + +var ( + marshalerType = reflect.TypeOf((*Marshaler)(nil)).Elem() + marshalerAttrType = reflect.TypeOf((*MarshalerAttr)(nil)).Elem() + textMarshalerType = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem() +) + +// marshalValue writes one or more XML elements representing val. +// If val was obtained from a struct field, finfo must have its details. 
+func (p *printer) marshalValue(val reflect.Value, finfo *fieldInfo, startTemplate *StartElement) error { + if startTemplate != nil && startTemplate.Name.Local == "" { + return fmt.Errorf("xml: EncodeElement of StartElement with missing name") + } + + if !val.IsValid() { + return nil + } + if finfo != nil && finfo.flags&fOmitEmpty != 0 && isEmptyValue(val) { + return nil + } + + // Drill into interfaces and pointers. + // This can turn into an infinite loop given a cyclic chain, + // but it matches the Go 1 behavior. + for val.Kind() == reflect.Interface || val.Kind() == reflect.Ptr { + if val.IsNil() { + return nil + } + val = val.Elem() + } + + kind := val.Kind() + typ := val.Type() + + // Check for marshaler. + if val.CanInterface() && typ.Implements(marshalerType) { + return p.marshalInterface(val.Interface().(Marshaler), p.defaultStart(typ, finfo, startTemplate)) + } + if val.CanAddr() { + pv := val.Addr() + if pv.CanInterface() && pv.Type().Implements(marshalerType) { + return p.marshalInterface(pv.Interface().(Marshaler), p.defaultStart(pv.Type(), finfo, startTemplate)) + } + } + + // Check for text marshaler. + if val.CanInterface() && typ.Implements(textMarshalerType) { + return p.marshalTextInterface(val.Interface().(encoding.TextMarshaler), p.defaultStart(typ, finfo, startTemplate)) + } + if val.CanAddr() { + pv := val.Addr() + if pv.CanInterface() && pv.Type().Implements(textMarshalerType) { + return p.marshalTextInterface(pv.Interface().(encoding.TextMarshaler), p.defaultStart(pv.Type(), finfo, startTemplate)) + } + } + + // Slices and arrays iterate over the elements. They do not have an enclosing tag. + if (kind == reflect.Slice || kind == reflect.Array) && typ.Elem().Kind() != reflect.Uint8 { + for i, n := 0, val.Len(); i < n; i++ { + if err := p.marshalValue(val.Index(i), finfo, startTemplate); err != nil { + return err + } + } + return nil + } + + tinfo, err := getTypeInfo(typ) + if err != nil { + return err + } + + // Create start element. + // Precedence for the XML element name is: + // 0. startTemplate + // 1. XMLName field in underlying struct; + // 2. field name/tag in the struct field; and + // 3. type name + var start StartElement + + // explicitNS records whether the element's name space has been + // explicitly set (for example an XMLName field). + explicitNS := false + + if startTemplate != nil { + start.Name = startTemplate.Name + explicitNS = true + start.Attr = append(start.Attr, startTemplate.Attr...) + } else if tinfo.xmlname != nil { + xmlname := tinfo.xmlname + if xmlname.name != "" { + start.Name.Space, start.Name.Local = xmlname.xmlns, xmlname.name + } else if v, ok := xmlname.value(val).Interface().(Name); ok && v.Local != "" { + start.Name = v + } + explicitNS = true + } + if start.Name.Local == "" && finfo != nil { + start.Name.Local = finfo.name + if finfo.xmlns != "" { + start.Name.Space = finfo.xmlns + explicitNS = true + } + } + if start.Name.Local == "" { + name := typ.Name() + if name == "" { + return &UnsupportedTypeError{typ} + } + start.Name.Local = name + } + + // defaultNS records the default name space as set by a xmlns="..." + // attribute. We don't set p.defaultNS because we want to let + // the attribute writing code (in p.defineNS) be solely responsible + // for maintaining that. 
+ defaultNS := p.defaultNS + + // Attributes + for i := range tinfo.fields { + finfo := &tinfo.fields[i] + if finfo.flags&fAttr == 0 { + continue + } + attr, err := p.fieldAttr(finfo, val) + if err != nil { + return err + } + if attr.Name.Local == "" { + continue + } + start.Attr = append(start.Attr, attr) + if attr.Name.Space == "" && attr.Name.Local == "xmlns" { + defaultNS = attr.Value + } + } + if !explicitNS { + // Historic behavior: elements use the default name space + // they are contained in by default. + start.Name.Space = defaultNS + } + // Historic behaviour: an element that's in a namespace sets + // the default namespace for all elements contained within it. + start.setDefaultNamespace() + + if err := p.writeStart(&start); err != nil { + return err + } + + if val.Kind() == reflect.Struct { + err = p.marshalStruct(tinfo, val) + } else { + s, b, err1 := p.marshalSimple(typ, val) + if err1 != nil { + err = err1 + } else if b != nil { + EscapeText(p, b) + } else { + p.EscapeString(s) + } + } + if err != nil { + return err + } + + if err := p.writeEnd(start.Name); err != nil { + return err + } + + return p.cachedWriteError() +} + +// fieldAttr returns the attribute of the given field. +// If the returned attribute has an empty Name.Local, +// it should not be used. +// The given value holds the value containing the field. +func (p *printer) fieldAttr(finfo *fieldInfo, val reflect.Value) (Attr, error) { + fv := finfo.value(val) + name := Name{Space: finfo.xmlns, Local: finfo.name} + if finfo.flags&fOmitEmpty != 0 && isEmptyValue(fv) { + return Attr{}, nil + } + if fv.Kind() == reflect.Interface && fv.IsNil() { + return Attr{}, nil + } + if fv.CanInterface() && fv.Type().Implements(marshalerAttrType) { + attr, err := fv.Interface().(MarshalerAttr).MarshalXMLAttr(name) + return attr, err + } + if fv.CanAddr() { + pv := fv.Addr() + if pv.CanInterface() && pv.Type().Implements(marshalerAttrType) { + attr, err := pv.Interface().(MarshalerAttr).MarshalXMLAttr(name) + return attr, err + } + } + if fv.CanInterface() && fv.Type().Implements(textMarshalerType) { + text, err := fv.Interface().(encoding.TextMarshaler).MarshalText() + if err != nil { + return Attr{}, err + } + return Attr{name, string(text)}, nil + } + if fv.CanAddr() { + pv := fv.Addr() + if pv.CanInterface() && pv.Type().Implements(textMarshalerType) { + text, err := pv.Interface().(encoding.TextMarshaler).MarshalText() + if err != nil { + return Attr{}, err + } + return Attr{name, string(text)}, nil + } + } + // Dereference or skip nil pointer, interface values. + switch fv.Kind() { + case reflect.Ptr, reflect.Interface: + if fv.IsNil() { + return Attr{}, nil + } + fv = fv.Elem() + } + s, b, err := p.marshalSimple(fv.Type(), fv) + if err != nil { + return Attr{}, err + } + if b != nil { + s = string(b) + } + return Attr{name, s}, nil +} + +// defaultStart returns the default start element to use, +// given the reflect type, field info, and start template. +func (p *printer) defaultStart(typ reflect.Type, finfo *fieldInfo, startTemplate *StartElement) StartElement { + var start StartElement + // Precedence for the XML element name is as above, + // except that we do not look inside structs for the first field. + if startTemplate != nil { + start.Name = startTemplate.Name + start.Attr = append(start.Attr, startTemplate.Attr...) 
+ } else if finfo != nil && finfo.name != "" { + start.Name.Local = finfo.name + start.Name.Space = finfo.xmlns + } else if typ.Name() != "" { + start.Name.Local = typ.Name() + } else { + // Must be a pointer to a named type, + // since it has the Marshaler methods. + start.Name.Local = typ.Elem().Name() + } + // Historic behaviour: elements use the name space of + // the element they are contained in by default. + if start.Name.Space == "" { + start.Name.Space = p.defaultNS + } + start.setDefaultNamespace() + return start +} + +// marshalInterface marshals a Marshaler interface value. +func (p *printer) marshalInterface(val Marshaler, start StartElement) error { + // Push a marker onto the tag stack so that MarshalXML + // cannot close the XML tags that it did not open. + p.tags = append(p.tags, Name{}) + n := len(p.tags) + + err := val.MarshalXML(p.encoder, start) + if err != nil { + return err + } + + // Make sure MarshalXML closed all its tags. p.tags[n-1] is the mark. + if len(p.tags) > n { + return fmt.Errorf("xml: %s.MarshalXML wrote invalid XML: <%s> not closed", receiverType(val), p.tags[len(p.tags)-1].Local) + } + p.tags = p.tags[:n-1] + return nil +} + +// marshalTextInterface marshals a TextMarshaler interface value. +func (p *printer) marshalTextInterface(val encoding.TextMarshaler, start StartElement) error { + if err := p.writeStart(&start); err != nil { + return err + } + text, err := val.MarshalText() + if err != nil { + return err + } + EscapeText(p, text) + return p.writeEnd(start.Name) +} + +// writeStart writes the given start element. +func (p *printer) writeStart(start *StartElement) error { + if start.Name.Local == "" { + return fmt.Errorf("xml: start tag with no name") + } + + p.tags = append(p.tags, start.Name) + p.markPrefix() + // Define any name spaces explicitly declared in the attributes. + // We do this as a separate pass so that explicitly declared prefixes + // will take precedence over implicitly declared prefixes + // regardless of the order of the attributes. + ignoreNonEmptyDefault := start.Name.Space == "" + for _, attr := range start.Attr { + if err := p.defineNS(attr, ignoreNonEmptyDefault); err != nil { + return err + } + } + // Define any new name spaces implied by the attributes. + for _, attr := range start.Attr { + name := attr.Name + // From http://www.w3.org/TR/xml-names11/#defaulting + // "Default namespace declarations do not apply directly + // to attribute names; the interpretation of unprefixed + // attributes is determined by the element on which they + // appear." + // This means we don't need to create a new namespace + // when an attribute name space is empty. + if name.Space != "" && !name.isNamespace() { + p.createNSPrefix(name.Space, true) + } + } + p.createNSPrefix(start.Name.Space, false) + + p.writeIndent(1) + p.WriteByte('<') + p.writeName(start.Name, false) + p.writeNamespaces() + for _, attr := range start.Attr { + name := attr.Name + if name.Local == "" || name.isNamespace() { + // Namespaces have already been written by writeNamespaces above. + continue + } + p.WriteByte(' ') + p.writeName(name, true) + p.WriteString(`="`) + p.EscapeString(attr.Value) + p.WriteByte('"') + } + p.WriteByte('>') + return nil +} + +// writeName writes the given name. It assumes +// that p.createNSPrefix(name) has already been called. 
+func (p *printer) writeName(name Name, isAttr bool) {
+	if prefix := p.prefixForNS(name.Space, isAttr); prefix != "" {
+		p.WriteString(prefix)
+		p.WriteByte(':')
+	}
+	p.WriteString(name.Local)
+}
+
+func (p *printer) writeEnd(name Name) error {
+	if name.Local == "" {
+		return fmt.Errorf("xml: end tag with no name")
+	}
+	if len(p.tags) == 0 || p.tags[len(p.tags)-1].Local == "" {
+		return fmt.Errorf("xml: end tag </%s> without start tag", name.Local)
+	}
+	if top := p.tags[len(p.tags)-1]; top != name {
+		if top.Local != name.Local {
+			return fmt.Errorf("xml: end tag </%s> does not match start tag <%s>", name.Local, top.Local)
+		}
+		return fmt.Errorf("xml: end tag </%s> in namespace %s does not match start tag <%s> in namespace %s", name.Local, name.Space, top.Local, top.Space)
+	}
+	p.tags = p.tags[:len(p.tags)-1]
+
+	p.writeIndent(-1)
+	p.WriteByte('<')
+	p.WriteByte('/')
+	p.writeName(name, false)
+	p.WriteByte('>')
+	p.popPrefix()
+	return nil
+}
+
+func (p *printer) marshalSimple(typ reflect.Type, val reflect.Value) (string, []byte, error) {
+	switch val.Kind() {
+	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+		return strconv.FormatInt(val.Int(), 10), nil, nil
+	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
+		return strconv.FormatUint(val.Uint(), 10), nil, nil
+	case reflect.Float32, reflect.Float64:
+		return strconv.FormatFloat(val.Float(), 'g', -1, val.Type().Bits()), nil, nil
+	case reflect.String:
+		return val.String(), nil, nil
+	case reflect.Bool:
+		return strconv.FormatBool(val.Bool()), nil, nil
+	case reflect.Array:
+		if typ.Elem().Kind() != reflect.Uint8 {
+			break
+		}
+		// [...]byte
+		var bytes []byte
+		if val.CanAddr() {
+			bytes = val.Slice(0, val.Len()).Bytes()
+		} else {
+			bytes = make([]byte, val.Len())
+			reflect.Copy(reflect.ValueOf(bytes), val)
+		}
+		return "", bytes, nil
+	case reflect.Slice:
+		if typ.Elem().Kind() != reflect.Uint8 {
+			break
+		}
+		// []byte
+		return "", val.Bytes(), nil
+	}
+	return "", nil, &UnsupportedTypeError{typ}
+}
+
+var ddBytes = []byte("--")
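Seen from the public API, marshalSimple is what turns scalar struct fields into element text; a small illustrative call (the Sample type is hypothetical):

	// Illustrative only: ints, floats and bools are formatted via
	// strconv; a []byte field is escaped and written through as-is.
	type Sample struct {
		N int     `xml:"n"`
		F float64 `xml:"f"`
		B bool    `xml:"b"`
		D []byte  `xml:"d"`
	}
	out, _ := Marshal(Sample{N: -7, F: 1.5, B: true, D: []byte("raw")})
	// string(out) == `<Sample><n>-7</n><f>1.5</f><b>true</b><d>raw</d></Sample>`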
+func (p *printer) marshalStruct(tinfo *typeInfo, val reflect.Value) error {
+	s := parentStack{p: p}
+	for i := range tinfo.fields {
+		finfo := &tinfo.fields[i]
+		if finfo.flags&fAttr != 0 {
+			continue
+		}
+		vf := finfo.value(val)
+
+		// Dereference or skip nil pointer, interface values.
+		switch vf.Kind() {
+		case reflect.Ptr, reflect.Interface:
+			if !vf.IsNil() {
+				vf = vf.Elem()
+			}
+		}
+
+		switch finfo.flags & fMode {
+		case fCharData:
+			if err := s.setParents(&noField, reflect.Value{}); err != nil {
+				return err
+			}
+			if vf.CanInterface() && vf.Type().Implements(textMarshalerType) {
+				data, err := vf.Interface().(encoding.TextMarshaler).MarshalText()
+				if err != nil {
+					return err
+				}
+				Escape(p, data)
+				continue
+			}
+			if vf.CanAddr() {
+				pv := vf.Addr()
+				if pv.CanInterface() && pv.Type().Implements(textMarshalerType) {
+					data, err := pv.Interface().(encoding.TextMarshaler).MarshalText()
+					if err != nil {
+						return err
+					}
+					Escape(p, data)
+					continue
+				}
+			}
+			var scratch [64]byte
+			switch vf.Kind() {
+			case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+				Escape(p, strconv.AppendInt(scratch[:0], vf.Int(), 10))
+			case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
+				Escape(p, strconv.AppendUint(scratch[:0], vf.Uint(), 10))
+			case reflect.Float32, reflect.Float64:
+				Escape(p, strconv.AppendFloat(scratch[:0], vf.Float(), 'g', -1, vf.Type().Bits()))
+			case reflect.Bool:
+				Escape(p, strconv.AppendBool(scratch[:0], vf.Bool()))
+			case reflect.String:
+				if err := EscapeText(p, []byte(vf.String())); err != nil {
+					return err
+				}
+			case reflect.Slice:
+				if elem, ok := vf.Interface().([]byte); ok {
+					if err := EscapeText(p, elem); err != nil {
+						return err
+					}
+				}
+			}
+			continue
+
+		case fComment:
+			if err := s.setParents(&noField, reflect.Value{}); err != nil {
+				return err
+			}
+			k := vf.Kind()
+			if !(k == reflect.String || k == reflect.Slice && vf.Type().Elem().Kind() == reflect.Uint8) {
+				return fmt.Errorf("xml: bad type for comment field of %s", val.Type())
+			}
+			if vf.Len() == 0 {
+				continue
+			}
+			p.writeIndent(0)
+			p.WriteString("<!--")
+			dashDash := false
+			dashLast := false
+			switch k {
+			case reflect.String:
+				s := vf.String()
+				dashDash = strings.Contains(s, "--")
+				dashLast = s[len(s)-1] == '-'
+				if !dashDash {
+					p.WriteString(s)
+				}
+			case reflect.Slice:
+				b := vf.Bytes()
+				dashDash = bytes.Contains(b, ddBytes)
+				dashLast = b[len(b)-1] == '-'
+				if !dashDash {
+					p.Write(b)
+				}
+			default:
+				panic("can't happen")
+			}
+			if dashDash {
+				return fmt.Errorf(`xml: comments must not contain "--"`)
+			}
+			if dashLast {
+				// "--->" is invalid grammar. Make it "- -->"
+				p.WriteByte(' ')
+			}
+			p.WriteString("-->")
+			continue
+
+		case fInnerXml:
+			iface := vf.Interface()
+			switch raw := iface.(type) {
+			case []byte:
+				p.Write(raw)
+				continue
+			case string:
+				p.WriteString(raw)
+				continue
+			}
+
+		case fElement, fElement | fAny:
+			if err := s.setParents(finfo, vf); err != nil {
+				return err
+			}
+		}
+		if err := p.marshalValue(vf, finfo, nil); err != nil {
+			return err
+		}
+	}
+	if err := s.setParents(&noField, reflect.Value{}); err != nil {
+		return err
+	}
+	return p.cachedWriteError()
+}
+
+var noField fieldInfo
+
+// return the bufio Writer's cached write error
+func (p *printer) cachedWriteError() error {
+	_, err := p.Write(nil)
+	return err
+}
+
+func (p *printer) writeIndent(depthDelta int) {
+	if len(p.prefix) == 0 && len(p.indent) == 0 {
+		return
+	}
+	if depthDelta < 0 {
+		p.depth--
+		if p.indentedIn {
+			p.indentedIn = false
+			return
+		}
+		p.indentedIn = false
+	}
+	if p.putNewline {
+		p.WriteByte('\n')
+	} else {
+		p.putNewline = true
+	}
+	if len(p.prefix) > 0 {
+		p.WriteString(p.prefix)
+	}
+	if len(p.indent) > 0 {
+		for i := 0; i < p.depth; i++ {
+			p.WriteString(p.indent)
+		}
+	}
+	if depthDelta > 0 {
+		p.depth++
+		p.indentedIn = true
+	}
+}
+
+type parentStack struct {
+	p       *printer
+	xmlns   string
+	parents []string
+}
+
+// setParents sets the stack of current parents to those found in finfo.
+// It only writes the start elements if vf holds a non-nil value.
+// If finfo is &noField, it pops all elements.
+func (s *parentStack) setParents(finfo *fieldInfo, vf reflect.Value) error { + xmlns := s.p.defaultNS + if finfo.xmlns != "" { + xmlns = finfo.xmlns + } + commonParents := 0 + if xmlns == s.xmlns { + for ; commonParents < len(finfo.parents) && commonParents < len(s.parents); commonParents++ { + if finfo.parents[commonParents] != s.parents[commonParents] { + break + } + } + } + // Pop off any parents that aren't in common with the previous field. + for i := len(s.parents) - 1; i >= commonParents; i-- { + if err := s.p.writeEnd(Name{ + Space: s.xmlns, + Local: s.parents[i], + }); err != nil { + return err + } + } + s.parents = finfo.parents + s.xmlns = xmlns + if commonParents >= len(s.parents) { + // No new elements to push. + return nil + } + if (vf.Kind() == reflect.Ptr || vf.Kind() == reflect.Interface) && vf.IsNil() { + // The element is nil, so no need for the start elements. + s.parents = s.parents[:commonParents] + return nil + } + // Push any new parents required. + for _, name := range s.parents[commonParents:] { + start := &StartElement{ + Name: Name{ + Space: s.xmlns, + Local: name, + }, + } + // Set the default name space for parent elements + // to match what we do with other elements. + if s.xmlns != s.p.defaultNS { + start.setDefaultNamespace() + } + if err := s.p.writeStart(start); err != nil { + return err + } + } + return nil +} + +// A MarshalXMLError is returned when Marshal encounters a type +// that cannot be converted into XML. +type UnsupportedTypeError struct { + Type reflect.Type +} + +func (e *UnsupportedTypeError) Error() string { + return "xml: unsupported type: " + e.Type.String() +} + +func isEmptyValue(v reflect.Value) bool { + switch v.Kind() { + case reflect.Array, reflect.Map, reflect.Slice, reflect.String: + return v.Len() == 0 + case reflect.Bool: + return !v.Bool() + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return v.Int() == 0 + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + return v.Uint() == 0 + case reflect.Float32, reflect.Float64: + return v.Float() == 0 + case reflect.Interface, reflect.Ptr: + return v.IsNil() + } + return false +} diff --git a/server/webdav/internal/xml/marshal_test.go b/server/webdav/internal/xml/marshal_test.go new file mode 100644 index 0000000000000000000000000000000000000000..226cfd013f0493c5a7db2a918ca61691f67d4b06 --- /dev/null +++ b/server/webdav/internal/xml/marshal_test.go @@ -0,0 +1,1939 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +package xml + +import ( + "bytes" + "errors" + "fmt" + "io" + "reflect" + "strconv" + "strings" + "sync" + "testing" + "time" +) + +type DriveType int + +const ( + HyperDrive DriveType = iota + ImprobabilityDrive +) + +type Passenger struct { + Name []string `xml:"name"` + Weight float32 `xml:"weight"` +} + +type Ship struct { + XMLName struct{} `xml:"spaceship"` + + Name string `xml:"name,attr"` + Pilot string `xml:"pilot,attr"` + Drive DriveType `xml:"drive"` + Age uint `xml:"age"` + Passenger []*Passenger `xml:"passenger"` + secret string +} + +type NamedType string + +type Port struct { + XMLName struct{} `xml:"port"` + Type string `xml:"type,attr,omitempty"` + Comment string `xml:",comment"` + Number string `xml:",chardata"` +} + +type Domain struct { + XMLName struct{} `xml:"domain"` + Country string `xml:",attr,omitempty"` + Name []byte `xml:",chardata"` + Comment []byte `xml:",comment"` +} + +type Book struct { + XMLName struct{} `xml:"book"` + Title string `xml:",chardata"` +} + +type Event struct { + XMLName struct{} `xml:"event"` + Year int `xml:",chardata"` +} + +type Movie struct { + XMLName struct{} `xml:"movie"` + Length uint `xml:",chardata"` +} + +type Pi struct { + XMLName struct{} `xml:"pi"` + Approximation float32 `xml:",chardata"` +} + +type Universe struct { + XMLName struct{} `xml:"universe"` + Visible float64 `xml:",chardata"` +} + +type Particle struct { + XMLName struct{} `xml:"particle"` + HasMass bool `xml:",chardata"` +} + +type Departure struct { + XMLName struct{} `xml:"departure"` + When time.Time `xml:",chardata"` +} + +type SecretAgent struct { + XMLName struct{} `xml:"agent"` + Handle string `xml:"handle,attr"` + Identity string + Obfuscate string `xml:",innerxml"` +} + +type NestedItems struct { + XMLName struct{} `xml:"result"` + Items []string `xml:">item"` + Item1 []string `xml:"Items>item1"` +} + +type NestedOrder struct { + XMLName struct{} `xml:"result"` + Field1 string `xml:"parent>c"` + Field2 string `xml:"parent>b"` + Field3 string `xml:"parent>a"` +} + +type MixedNested struct { + XMLName struct{} `xml:"result"` + A string `xml:"parent1>a"` + B string `xml:"b"` + C string `xml:"parent1>parent2>c"` + D string `xml:"parent1>d"` +} + +type NilTest struct { + A interface{} `xml:"parent1>parent2>a"` + B interface{} `xml:"parent1>b"` + C interface{} `xml:"parent1>parent2>c"` +} + +type Service struct { + XMLName struct{} `xml:"service"` + Domain *Domain `xml:"host>domain"` + Port *Port `xml:"host>port"` + Extra1 interface{} + Extra2 interface{} `xml:"host>extra2"` +} + +var nilStruct *Ship + +type EmbedA struct { + EmbedC + EmbedB EmbedB + FieldA string +} + +type EmbedB struct { + FieldB string + *EmbedC +} + +type EmbedC struct { + FieldA1 string `xml:"FieldA>A1"` + FieldA2 string `xml:"FieldA>A2"` + FieldB string + FieldC string +} + +type NameCasing struct { + XMLName struct{} `xml:"casing"` + Xy string + XY string + XyA string `xml:"Xy,attr"` + XYA string `xml:"XY,attr"` +} + +type NamePrecedence struct { + XMLName Name `xml:"Parent"` + FromTag XMLNameWithoutTag `xml:"InTag"` + FromNameVal XMLNameWithoutTag + FromNameTag XMLNameWithTag + InFieldName string +} + +type XMLNameWithTag struct { + XMLName Name `xml:"InXMLNameTag"` + Value string `xml:",chardata"` +} + +type XMLNameWithNSTag struct { + XMLName Name `xml:"ns InXMLNameWithNSTag"` + Value string `xml:",chardata"` +} + +type XMLNameWithoutTag struct { + XMLName Name + Value string `xml:",chardata"` +} + +type NameInField struct { + Foo Name `xml:"ns foo"` +} + +type AttrTest struct { + 
Int int `xml:",attr"` + Named int `xml:"int,attr"` + Float float64 `xml:",attr"` + Uint8 uint8 `xml:",attr"` + Bool bool `xml:",attr"` + Str string `xml:",attr"` + Bytes []byte `xml:",attr"` +} + +type OmitAttrTest struct { + Int int `xml:",attr,omitempty"` + Named int `xml:"int,attr,omitempty"` + Float float64 `xml:",attr,omitempty"` + Uint8 uint8 `xml:",attr,omitempty"` + Bool bool `xml:",attr,omitempty"` + Str string `xml:",attr,omitempty"` + Bytes []byte `xml:",attr,omitempty"` +} + +type OmitFieldTest struct { + Int int `xml:",omitempty"` + Named int `xml:"int,omitempty"` + Float float64 `xml:",omitempty"` + Uint8 uint8 `xml:",omitempty"` + Bool bool `xml:",omitempty"` + Str string `xml:",omitempty"` + Bytes []byte `xml:",omitempty"` + Ptr *PresenceTest `xml:",omitempty"` +} + +type AnyTest struct { + XMLName struct{} `xml:"a"` + Nested string `xml:"nested>value"` + AnyField AnyHolder `xml:",any"` +} + +type AnyOmitTest struct { + XMLName struct{} `xml:"a"` + Nested string `xml:"nested>value"` + AnyField *AnyHolder `xml:",any,omitempty"` +} + +type AnySliceTest struct { + XMLName struct{} `xml:"a"` + Nested string `xml:"nested>value"` + AnyField []AnyHolder `xml:",any"` +} + +type AnyHolder struct { + XMLName Name + XML string `xml:",innerxml"` +} + +type RecurseA struct { + A string + B *RecurseB +} + +type RecurseB struct { + A *RecurseA + B string +} + +type PresenceTest struct { + Exists *struct{} +} + +type IgnoreTest struct { + PublicSecret string `xml:"-"` +} + +type MyBytes []byte + +type Data struct { + Bytes []byte + Attr []byte `xml:",attr"` + Custom MyBytes +} + +type Plain struct { + V interface{} +} + +type MyInt int + +type EmbedInt struct { + MyInt +} + +type Strings struct { + X []string `xml:"A>B,omitempty"` +} + +type PointerFieldsTest struct { + XMLName Name `xml:"dummy"` + Name *string `xml:"name,attr"` + Age *uint `xml:"age,attr"` + Empty *string `xml:"empty,attr"` + Contents *string `xml:",chardata"` +} + +type ChardataEmptyTest struct { + XMLName Name `xml:"test"` + Contents *string `xml:",chardata"` +} + +type MyMarshalerTest struct { +} + +var _ Marshaler = (*MyMarshalerTest)(nil) + +func (m *MyMarshalerTest) MarshalXML(e *Encoder, start StartElement) error { + e.EncodeToken(start) + e.EncodeToken(CharData([]byte("hello world"))) + e.EncodeToken(EndElement{start.Name}) + return nil +} + +type MyMarshalerAttrTest struct{} + +var _ MarshalerAttr = (*MyMarshalerAttrTest)(nil) + +func (m *MyMarshalerAttrTest) MarshalXMLAttr(name Name) (Attr, error) { + return Attr{name, "hello world"}, nil +} + +type MyMarshalerValueAttrTest struct{} + +var _ MarshalerAttr = MyMarshalerValueAttrTest{} + +func (m MyMarshalerValueAttrTest) MarshalXMLAttr(name Name) (Attr, error) { + return Attr{name, "hello world"}, nil +} + +type MarshalerStruct struct { + Foo MyMarshalerAttrTest `xml:",attr"` +} + +type MarshalerValueStruct struct { + Foo MyMarshalerValueAttrTest `xml:",attr"` +} + +type InnerStruct struct { + XMLName Name `xml:"testns outer"` +} + +type OuterStruct struct { + InnerStruct + IntAttr int `xml:"int,attr"` +} + +type OuterNamedStruct struct { + InnerStruct + XMLName Name `xml:"outerns test"` + IntAttr int `xml:"int,attr"` +} + +type OuterNamedOrderedStruct struct { + XMLName Name `xml:"outerns test"` + InnerStruct + IntAttr int `xml:"int,attr"` +} + +type OuterOuterStruct struct { + OuterStruct +} + +type NestedAndChardata struct { + AB []string `xml:"A>B"` + Chardata string `xml:",chardata"` +} + +type NestedAndComment struct { + AB []string `xml:"A>B"` + Comment 
string `xml:",comment"` +} + +type XMLNSFieldStruct struct { + Ns string `xml:"xmlns,attr"` + Body string +} + +type NamedXMLNSFieldStruct struct { + XMLName struct{} `xml:"testns test"` + Ns string `xml:"xmlns,attr"` + Body string +} + +type XMLNSFieldStructWithOmitEmpty struct { + Ns string `xml:"xmlns,attr,omitempty"` + Body string +} + +type NamedXMLNSFieldStructWithEmptyNamespace struct { + XMLName struct{} `xml:"test"` + Ns string `xml:"xmlns,attr"` + Body string +} + +type RecursiveXMLNSFieldStruct struct { + Ns string `xml:"xmlns,attr"` + Body *RecursiveXMLNSFieldStruct `xml:",omitempty"` + Text string `xml:",omitempty"` +} + +func ifaceptr(x interface{}) interface{} { + return &x +} + +var ( + nameAttr = "Sarah" + ageAttr = uint(12) + contentsAttr = "lorem ipsum" +) + +// Unless explicitly stated as such (or *Plain), all of the +// tests below are two-way tests. When introducing new tests, +// please try to make them two-way as well to ensure that +// marshalling and unmarshalling are as symmetrical as feasible. +var marshalTests = []struct { + Value interface{} + ExpectXML string + MarshalOnly bool + UnmarshalOnly bool +}{ + // Test nil marshals to nothing + {Value: nil, ExpectXML: ``, MarshalOnly: true}, + {Value: nilStruct, ExpectXML: ``, MarshalOnly: true}, + + // Test value types + {Value: &Plain{true}, ExpectXML: `true`}, + {Value: &Plain{false}, ExpectXML: `false`}, + {Value: &Plain{int(42)}, ExpectXML: `42`}, + {Value: &Plain{int8(42)}, ExpectXML: `42`}, + {Value: &Plain{int16(42)}, ExpectXML: `42`}, + {Value: &Plain{int32(42)}, ExpectXML: `42`}, + {Value: &Plain{uint(42)}, ExpectXML: `42`}, + {Value: &Plain{uint8(42)}, ExpectXML: `42`}, + {Value: &Plain{uint16(42)}, ExpectXML: `42`}, + {Value: &Plain{uint32(42)}, ExpectXML: `42`}, + {Value: &Plain{float32(1.25)}, ExpectXML: `1.25`}, + {Value: &Plain{float64(1.25)}, ExpectXML: `1.25`}, + {Value: &Plain{uintptr(0xFFDD)}, ExpectXML: `65501`}, + {Value: &Plain{"gopher"}, ExpectXML: `gopher`}, + {Value: &Plain{[]byte("gopher")}, ExpectXML: `gopher`}, + {Value: &Plain{""}, ExpectXML: `</>`}, + {Value: &Plain{[]byte("")}, ExpectXML: `</>`}, + {Value: &Plain{[3]byte{'<', '/', '>'}}, ExpectXML: `</>`}, + {Value: &Plain{NamedType("potato")}, ExpectXML: `potato`}, + {Value: &Plain{[]int{1, 2, 3}}, ExpectXML: `123`}, + {Value: &Plain{[3]int{1, 2, 3}}, ExpectXML: `123`}, + {Value: ifaceptr(true), MarshalOnly: true, ExpectXML: `true`}, + + // Test time. + { + Value: &Plain{time.Unix(1e9, 123456789).UTC()}, + ExpectXML: `2001-09-09T01:46:40.123456789Z`, + }, + + // A pointer to struct{} may be used to test for an element's presence. + { + Value: &PresenceTest{new(struct{})}, + ExpectXML: ``, + }, + { + Value: &PresenceTest{}, + ExpectXML: ``, + }, + + // A pointer to struct{} may be used to test for an element's presence. + { + Value: &PresenceTest{new(struct{})}, + ExpectXML: ``, + }, + { + Value: &PresenceTest{}, + ExpectXML: ``, + }, + + // A []byte field is only nil if the element was not found. + { + Value: &Data{}, + ExpectXML: ``, + UnmarshalOnly: true, + }, + { + Value: &Data{Bytes: []byte{}, Custom: MyBytes{}, Attr: []byte{}}, + ExpectXML: ``, + UnmarshalOnly: true, + }, + + // Check that []byte works, including named []byte types. 
+ { + Value: &Data{Bytes: []byte("ab"), Custom: MyBytes("cd"), Attr: []byte{'v'}}, + ExpectXML: `abcd`, + }, + + // Test innerxml + { + Value: &SecretAgent{ + Handle: "007", + Identity: "James Bond", + Obfuscate: "", + }, + ExpectXML: `James Bond`, + MarshalOnly: true, + }, + { + Value: &SecretAgent{ + Handle: "007", + Identity: "James Bond", + Obfuscate: "James Bond", + }, + ExpectXML: `James Bond`, + UnmarshalOnly: true, + }, + + // Test structs + {Value: &Port{Type: "ssl", Number: "443"}, ExpectXML: `443`}, + {Value: &Port{Number: "443"}, ExpectXML: `443`}, + {Value: &Port{Type: ""}, ExpectXML: ``}, + {Value: &Port{Number: "443", Comment: "https"}, ExpectXML: `443`}, + {Value: &Port{Number: "443", Comment: "add space-"}, ExpectXML: `443`, MarshalOnly: true}, + {Value: &Domain{Name: []byte("google.com&friends")}, ExpectXML: `google.com&friends`}, + {Value: &Domain{Name: []byte("google.com"), Comment: []byte(" &friends ")}, ExpectXML: `google.com`}, + {Value: &Book{Title: "Pride & Prejudice"}, ExpectXML: `Pride & Prejudice`}, + {Value: &Event{Year: -3114}, ExpectXML: `-3114`}, + {Value: &Movie{Length: 13440}, ExpectXML: `13440`}, + {Value: &Pi{Approximation: 3.14159265}, ExpectXML: `3.1415927`}, + {Value: &Universe{Visible: 9.3e13}, ExpectXML: `9.3e+13`}, + {Value: &Particle{HasMass: true}, ExpectXML: `true`}, + {Value: &Departure{When: ParseTime("2013-01-09T00:15:00-09:00")}, ExpectXML: `2013-01-09T00:15:00-09:00`}, + {Value: atomValue, ExpectXML: atomXml}, + { + Value: &Ship{ + Name: "Heart of Gold", + Pilot: "Computer", + Age: 1, + Drive: ImprobabilityDrive, + Passenger: []*Passenger{ + { + Name: []string{"Zaphod", "Beeblebrox"}, + Weight: 7.25, + }, + { + Name: []string{"Trisha", "McMillen"}, + Weight: 5.5, + }, + { + Name: []string{"Ford", "Prefect"}, + Weight: 7, + }, + { + Name: []string{"Arthur", "Dent"}, + Weight: 6.75, + }, + }, + }, + ExpectXML: `` + + `` + strconv.Itoa(int(ImprobabilityDrive)) + `` + + `1` + + `` + + `Zaphod` + + `Beeblebrox` + + `7.25` + + `` + + `` + + `Trisha` + + `McMillen` + + `5.5` + + `` + + `` + + `Ford` + + `Prefect` + + `7` + + `` + + `` + + `Arthur` + + `Dent` + + `6.75` + + `` + + ``, + }, + + // Test a>b + { + Value: &NestedItems{Items: nil, Item1: nil}, + ExpectXML: `` + + `` + + `` + + ``, + }, + { + Value: &NestedItems{Items: []string{}, Item1: []string{}}, + ExpectXML: `` + + `` + + `` + + ``, + MarshalOnly: true, + }, + { + Value: &NestedItems{Items: nil, Item1: []string{"A"}}, + ExpectXML: `` + + `` + + `A` + + `` + + ``, + }, + { + Value: &NestedItems{Items: []string{"A", "B"}, Item1: nil}, + ExpectXML: `` + + `` + + `A` + + `B` + + `` + + ``, + }, + { + Value: &NestedItems{Items: []string{"A", "B"}, Item1: []string{"C"}}, + ExpectXML: `` + + `` + + `A` + + `B` + + `C` + + `` + + ``, + }, + { + Value: &NestedOrder{Field1: "C", Field2: "B", Field3: "A"}, + ExpectXML: `` + + `` + + `C` + + `B` + + `A` + + `` + + ``, + }, + { + Value: &NilTest{A: "A", B: nil, C: "C"}, + ExpectXML: `` + + `` + + `A` + + `C` + + `` + + ``, + MarshalOnly: true, // Uses interface{} + }, + { + Value: &MixedNested{A: "A", B: "B", C: "C", D: "D"}, + ExpectXML: `` + + `A` + + `B` + + `` + + `C` + + `D` + + `` + + ``, + }, + { + Value: &Service{Port: &Port{Number: "80"}}, + ExpectXML: `80`, + }, + { + Value: &Service{}, + ExpectXML: ``, + }, + { + Value: &Service{Port: &Port{Number: "80"}, Extra1: "A", Extra2: "B"}, + ExpectXML: `` + + `80` + + `A` + + `B` + + ``, + MarshalOnly: true, + }, + { + Value: &Service{Port: &Port{Number: "80"}, Extra2: "example"}, + 
ExpectXML: `` + + `80` + + `example` + + ``, + MarshalOnly: true, + }, + { + Value: &struct { + XMLName struct{} `xml:"space top"` + A string `xml:"x>a"` + B string `xml:"x>b"` + C string `xml:"space x>c"` + C1 string `xml:"space1 x>c"` + D1 string `xml:"space1 x>d"` + E1 string `xml:"x>e"` + }{ + A: "a", + B: "b", + C: "c", + C1: "c1", + D1: "d1", + E1: "e1", + }, + ExpectXML: `` + + `abc` + + `` + + `c1` + + `d1` + + `` + + `` + + `e1` + + `` + + ``, + }, + { + Value: &struct { + XMLName Name + A string `xml:"x>a"` + B string `xml:"x>b"` + C string `xml:"space x>c"` + C1 string `xml:"space1 x>c"` + D1 string `xml:"space1 x>d"` + }{ + XMLName: Name{ + Space: "space0", + Local: "top", + }, + A: "a", + B: "b", + C: "c", + C1: "c1", + D1: "d1", + }, + ExpectXML: `` + + `ab` + + `c` + + `` + + `c1` + + `d1` + + `` + + ``, + }, + { + Value: &struct { + XMLName struct{} `xml:"top"` + B string `xml:"space x>b"` + B1 string `xml:"space1 x>b"` + }{ + B: "b", + B1: "b1", + }, + ExpectXML: `` + + `b` + + `b1` + + ``, + }, + + // Test struct embedding + { + Value: &EmbedA{ + EmbedC: EmbedC{ + FieldA1: "", // Shadowed by A.A + FieldA2: "", // Shadowed by A.A + FieldB: "A.C.B", + FieldC: "A.C.C", + }, + EmbedB: EmbedB{ + FieldB: "A.B.B", + EmbedC: &EmbedC{ + FieldA1: "A.B.C.A1", + FieldA2: "A.B.C.A2", + FieldB: "", // Shadowed by A.B.B + FieldC: "A.B.C.C", + }, + }, + FieldA: "A.A", + }, + ExpectXML: `` + + `A.C.B` + + `A.C.C` + + `` + + `A.B.B` + + `` + + `A.B.C.A1` + + `A.B.C.A2` + + `` + + `A.B.C.C` + + `` + + `A.A` + + ``, + }, + + // Test that name casing matters + { + Value: &NameCasing{Xy: "mixed", XY: "upper", XyA: "mixedA", XYA: "upperA"}, + ExpectXML: `mixedupper`, + }, + + // Test the order in which the XML element name is chosen + { + Value: &NamePrecedence{ + FromTag: XMLNameWithoutTag{Value: "A"}, + FromNameVal: XMLNameWithoutTag{XMLName: Name{Local: "InXMLName"}, Value: "B"}, + FromNameTag: XMLNameWithTag{Value: "C"}, + InFieldName: "D", + }, + ExpectXML: `` + + `A` + + `B` + + `C` + + `D` + + ``, + MarshalOnly: true, + }, + { + Value: &NamePrecedence{ + XMLName: Name{Local: "Parent"}, + FromTag: XMLNameWithoutTag{XMLName: Name{Local: "InTag"}, Value: "A"}, + FromNameVal: XMLNameWithoutTag{XMLName: Name{Local: "FromNameVal"}, Value: "B"}, + FromNameTag: XMLNameWithTag{XMLName: Name{Local: "InXMLNameTag"}, Value: "C"}, + InFieldName: "D", + }, + ExpectXML: `` + + `A` + + `B` + + `C` + + `D` + + ``, + UnmarshalOnly: true, + }, + + // xml.Name works in a plain field as well. + { + Value: &NameInField{Name{Space: "ns", Local: "foo"}}, + ExpectXML: ``, + }, + { + Value: &NameInField{Name{Space: "ns", Local: "foo"}}, + ExpectXML: ``, + UnmarshalOnly: true, + }, + + // Marshaling zero xml.Name uses the tag or field name. 
+ { + Value: &NameInField{}, + ExpectXML: ``, + MarshalOnly: true, + }, + + // Test attributes + { + Value: &AttrTest{ + Int: 8, + Named: 9, + Float: 23.5, + Uint8: 255, + Bool: true, + Str: "str", + Bytes: []byte("byt"), + }, + ExpectXML: ``, + }, + { + Value: &AttrTest{Bytes: []byte{}}, + ExpectXML: ``, + }, + { + Value: &OmitAttrTest{ + Int: 8, + Named: 9, + Float: 23.5, + Uint8: 255, + Bool: true, + Str: "str", + Bytes: []byte("byt"), + }, + ExpectXML: ``, + }, + { + Value: &OmitAttrTest{}, + ExpectXML: ``, + }, + + // pointer fields + { + Value: &PointerFieldsTest{Name: &nameAttr, Age: &ageAttr, Contents: &contentsAttr}, + ExpectXML: `lorem ipsum`, + MarshalOnly: true, + }, + + // empty chardata pointer field + { + Value: &ChardataEmptyTest{}, + ExpectXML: ``, + MarshalOnly: true, + }, + + // omitempty on fields + { + Value: &OmitFieldTest{ + Int: 8, + Named: 9, + Float: 23.5, + Uint8: 255, + Bool: true, + Str: "str", + Bytes: []byte("byt"), + Ptr: &PresenceTest{}, + }, + ExpectXML: `` + + `8` + + `9` + + `23.5` + + `255` + + `true` + + `str` + + `byt` + + `` + + ``, + }, + { + Value: &OmitFieldTest{}, + ExpectXML: ``, + }, + + // Test ",any" + { + ExpectXML: `knownunknown`, + Value: &AnyTest{ + Nested: "known", + AnyField: AnyHolder{ + XMLName: Name{Local: "other"}, + XML: "unknown", + }, + }, + }, + { + Value: &AnyTest{Nested: "known", + AnyField: AnyHolder{ + XML: "", + XMLName: Name{Local: "AnyField"}, + }, + }, + ExpectXML: `known`, + }, + { + ExpectXML: `b`, + Value: &AnyOmitTest{ + Nested: "b", + }, + }, + { + ExpectXML: `bei`, + Value: &AnySliceTest{ + Nested: "b", + AnyField: []AnyHolder{ + { + XMLName: Name{Local: "c"}, + XML: "e", + }, + { + XMLName: Name{Space: "f", Local: "g"}, + XML: "i", + }, + }, + }, + }, + { + ExpectXML: `b`, + Value: &AnySliceTest{ + Nested: "b", + }, + }, + + // Test recursive types. + { + Value: &RecurseA{ + A: "a1", + B: &RecurseB{ + A: &RecurseA{"a2", nil}, + B: "b1", + }, + }, + ExpectXML: `a1a2b1`, + }, + + // Test ignoring fields via "-" tag + { + ExpectXML: ``, + Value: &IgnoreTest{}, + }, + { + ExpectXML: ``, + Value: &IgnoreTest{PublicSecret: "can't tell"}, + MarshalOnly: true, + }, + { + ExpectXML: `ignore me`, + Value: &IgnoreTest{}, + UnmarshalOnly: true, + }, + + // Test escaping. + { + ExpectXML: `dquote: "; squote: '; ampersand: &; less: <; greater: >;`, + Value: &AnyTest{ + Nested: `dquote: "; squote: '; ampersand: &; less: <; greater: >;`, + AnyField: AnyHolder{XMLName: Name{Local: "empty"}}, + }, + }, + { + ExpectXML: `newline: ; cr: ; tab: ;`, + Value: &AnyTest{ + Nested: "newline: \n; cr: \r; tab: \t;", + AnyField: AnyHolder{XMLName: Name{Local: "AnyField"}}, + }, + }, + { + ExpectXML: "1\r2\r\n3\n\r4\n5", + Value: &AnyTest{ + Nested: "1\n2\n3\n\n4\n5", + }, + UnmarshalOnly: true, + }, + { + ExpectXML: `42`, + Value: &EmbedInt{ + MyInt: 42, + }, + }, + // Test omitempty with parent chain; see golang.org/issue/4168. + { + ExpectXML: ``, + Value: &Strings{}, + }, + // Custom marshalers. 
+ { + ExpectXML: `hello world`, + Value: &MyMarshalerTest{}, + }, + { + ExpectXML: ``, + Value: &MarshalerStruct{}, + }, + { + ExpectXML: ``, + Value: &MarshalerValueStruct{}, + }, + { + ExpectXML: ``, + Value: &OuterStruct{IntAttr: 10}, + }, + { + ExpectXML: ``, + Value: &OuterNamedStruct{XMLName: Name{Space: "outerns", Local: "test"}, IntAttr: 10}, + }, + { + ExpectXML: ``, + Value: &OuterNamedOrderedStruct{XMLName: Name{Space: "outerns", Local: "test"}, IntAttr: 10}, + }, + { + ExpectXML: ``, + Value: &OuterOuterStruct{OuterStruct{IntAttr: 10}}, + }, + { + ExpectXML: `test`, + Value: &NestedAndChardata{AB: make([]string, 2), Chardata: "test"}, + }, + { + ExpectXML: ``, + Value: &NestedAndComment{AB: make([]string, 2), Comment: "test"}, + }, + { + ExpectXML: `hello world`, + Value: &XMLNSFieldStruct{Ns: "http://example.com/ns", Body: "hello world"}, + }, + { + ExpectXML: `hello world`, + Value: &NamedXMLNSFieldStruct{Ns: "http://example.com/ns", Body: "hello world"}, + }, + { + ExpectXML: `hello world`, + Value: &NamedXMLNSFieldStruct{Ns: "", Body: "hello world"}, + }, + { + ExpectXML: `hello world`, + Value: &XMLNSFieldStructWithOmitEmpty{Body: "hello world"}, + }, + { + // The xmlns attribute must be ignored because the + // element is in the empty namespace, so it's not possible + // to set the default namespace to something non-empty. + ExpectXML: `hello world`, + Value: &NamedXMLNSFieldStructWithEmptyNamespace{Ns: "foo", Body: "hello world"}, + MarshalOnly: true, + }, + { + ExpectXML: `hello world`, + Value: &RecursiveXMLNSFieldStruct{ + Ns: "foo", + Body: &RecursiveXMLNSFieldStruct{ + Text: "hello world", + }, + }, + }, +} + +func TestMarshal(t *testing.T) { + for idx, test := range marshalTests { + if test.UnmarshalOnly { + continue + } + data, err := Marshal(test.Value) + if err != nil { + t.Errorf("#%d: marshal(%#v): %s", idx, test.Value, err) + continue + } + if got, want := string(data), test.ExpectXML; got != want { + if strings.Contains(want, "\n") { + t.Errorf("#%d: marshal(%#v):\nHAVE:\n%s\nWANT:\n%s", idx, test.Value, got, want) + } else { + t.Errorf("#%d: marshal(%#v):\nhave %#q\nwant %#q", idx, test.Value, got, want) + } + } + } +} + +type AttrParent struct { + X string `xml:"X>Y,attr"` +} + +type BadAttr struct { + Name []string `xml:"name,attr"` +} + +var marshalErrorTests = []struct { + Value interface{} + Err string + Kind reflect.Kind +}{ + { + Value: make(chan bool), + Err: "xml: unsupported type: chan bool", + Kind: reflect.Chan, + }, + { + Value: map[string]string{ + "question": "What do you get when you multiply six by nine?", + "answer": "42", + }, + Err: "xml: unsupported type: map[string]string", + Kind: reflect.Map, + }, + { + Value: map[*Ship]bool{nil: false}, + Err: "xml: unsupported type: map[*xml.Ship]bool", + Kind: reflect.Map, + }, + { + Value: &Domain{Comment: []byte("f--bar")}, + Err: `xml: comments must not contain "--"`, + }, + // Reject parent chain with attr, never worked; see golang.org/issue/5033. 
+ { + Value: &AttrParent{}, + Err: `xml: X>Y chain not valid with attr flag`, + }, + { + Value: BadAttr{[]string{"X", "Y"}}, + Err: `xml: unsupported type: []string`, + }, +} + +var marshalIndentTests = []struct { + Value interface{} + Prefix string + Indent string + ExpectXML string +}{ + { + Value: &SecretAgent{ + Handle: "007", + Identity: "James Bond", + Obfuscate: "", + }, + Prefix: "", + Indent: "\t", + ExpectXML: fmt.Sprintf("\n\tJames Bond\n"), + }, +} + +func TestMarshalErrors(t *testing.T) { + for idx, test := range marshalErrorTests { + data, err := Marshal(test.Value) + if err == nil { + t.Errorf("#%d: marshal(%#v) = [success] %q, want error %v", idx, test.Value, data, test.Err) + continue + } + if err.Error() != test.Err { + t.Errorf("#%d: marshal(%#v) = [error] %v, want %v", idx, test.Value, err, test.Err) + } + if test.Kind != reflect.Invalid { + if kind := err.(*UnsupportedTypeError).Type.Kind(); kind != test.Kind { + t.Errorf("#%d: marshal(%#v) = [error kind] %s, want %s", idx, test.Value, kind, test.Kind) + } + } + } +} + +// Do invertibility testing on the various structures that we test +func TestUnmarshal(t *testing.T) { + for i, test := range marshalTests { + if test.MarshalOnly { + continue + } + if _, ok := test.Value.(*Plain); ok { + continue + } + vt := reflect.TypeOf(test.Value) + dest := reflect.New(vt.Elem()).Interface() + err := Unmarshal([]byte(test.ExpectXML), dest) + + switch fix := dest.(type) { + case *Feed: + fix.Author.InnerXML = "" + for i := range fix.Entry { + fix.Entry[i].Author.InnerXML = "" + } + } + + if err != nil { + t.Errorf("#%d: unexpected error: %#v", i, err) + } else if got, want := dest, test.Value; !reflect.DeepEqual(got, want) { + t.Errorf("#%d: unmarshal(%q):\nhave %#v\nwant %#v", i, test.ExpectXML, got, want) + } + } +} + +func TestMarshalIndent(t *testing.T) { + for i, test := range marshalIndentTests { + data, err := MarshalIndent(test.Value, test.Prefix, test.Indent) + if err != nil { + t.Errorf("#%d: Error: %s", i, err) + continue + } + if got, want := string(data), test.ExpectXML; got != want { + t.Errorf("#%d: MarshalIndent:\nGot:%s\nWant:\n%s", i, got, want) + } + } +} + +type limitedBytesWriter struct { + w io.Writer + remain int // until writes fail +} + +func (lw *limitedBytesWriter) Write(p []byte) (n int, err error) { + if lw.remain <= 0 { + println("error") + return 0, errors.New("write limit hit") + } + if len(p) > lw.remain { + p = p[:lw.remain] + n, _ = lw.w.Write(p) + lw.remain = 0 + return n, errors.New("write limit hit") + } + n, err = lw.w.Write(p) + lw.remain -= n + return n, err +} + +func TestMarshalWriteErrors(t *testing.T) { + var buf bytes.Buffer + const writeCap = 1024 + w := &limitedBytesWriter{&buf, writeCap} + enc := NewEncoder(w) + var err error + var i int + const n = 4000 + for i = 1; i <= n; i++ { + err = enc.Encode(&Passenger{ + Name: []string{"Alice", "Bob"}, + Weight: 5, + }) + if err != nil { + break + } + } + if err == nil { + t.Error("expected an error") + } + if i == n { + t.Errorf("expected to fail before the end") + } + if buf.Len() != writeCap { + t.Errorf("buf.Len() = %d; want %d", buf.Len(), writeCap) + } +} + +func TestMarshalWriteIOErrors(t *testing.T) { + enc := NewEncoder(errWriter{}) + + expectErr := "unwritable" + err := enc.Encode(&Passenger{}) + if err == nil || err.Error() != expectErr { + t.Errorf("EscapeTest = [error] %v, want %v", err, expectErr) + } +} + +func TestMarshalFlush(t *testing.T) { + var buf bytes.Buffer + enc := NewEncoder(&buf) + if err := 
enc.EncodeToken(CharData("hello world")); err != nil { + t.Fatalf("enc.EncodeToken: %v", err) + } + if buf.Len() > 0 { + t.Fatalf("enc.EncodeToken caused actual write: %q", buf.Bytes()) + } + if err := enc.Flush(); err != nil { + t.Fatalf("enc.Flush: %v", err) + } + if buf.String() != "hello world" { + t.Fatalf("after enc.Flush, buf.String() = %q, want %q", buf.String(), "hello world") + } +} + +var encodeElementTests = []struct { + desc string + value interface{} + start StartElement + expectXML string +}{{ + desc: "simple string", + value: "hello", + start: StartElement{ + Name: Name{Local: "a"}, + }, + expectXML: `hello`, +}, { + desc: "string with added attributes", + value: "hello", + start: StartElement{ + Name: Name{Local: "a"}, + Attr: []Attr{{ + Name: Name{Local: "x"}, + Value: "y", + }, { + Name: Name{Local: "foo"}, + Value: "bar", + }}, + }, + expectXML: `hello`, +}, { + desc: "start element with default name space", + value: struct { + Foo XMLNameWithNSTag + }{ + Foo: XMLNameWithNSTag{ + Value: "hello", + }, + }, + start: StartElement{ + Name: Name{Space: "ns", Local: "a"}, + Attr: []Attr{{ + Name: Name{Local: "xmlns"}, + // "ns" is the name space defined in XMLNameWithNSTag + Value: "ns", + }}, + }, + expectXML: `hello`, +}, { + desc: "start element in name space with different default name space", + value: struct { + Foo XMLNameWithNSTag + }{ + Foo: XMLNameWithNSTag{ + Value: "hello", + }, + }, + start: StartElement{ + Name: Name{Space: "ns2", Local: "a"}, + Attr: []Attr{{ + Name: Name{Local: "xmlns"}, + // "ns" is the name space defined in XMLNameWithNSTag + Value: "ns", + }}, + }, + expectXML: `hello`, +}, { + desc: "XMLMarshaler with start element with default name space", + value: &MyMarshalerTest{}, + start: StartElement{ + Name: Name{Space: "ns2", Local: "a"}, + Attr: []Attr{{ + Name: Name{Local: "xmlns"}, + // "ns" is the name space defined in XMLNameWithNSTag + Value: "ns", + }}, + }, + expectXML: `hello world`, +}} + +func TestEncodeElement(t *testing.T) { + for idx, test := range encodeElementTests { + var buf bytes.Buffer + enc := NewEncoder(&buf) + err := enc.EncodeElement(test.value, test.start) + if err != nil { + t.Fatalf("enc.EncodeElement: %v", err) + } + err = enc.Flush() + if err != nil { + t.Fatalf("enc.Flush: %v", err) + } + if got, want := buf.String(), test.expectXML; got != want { + t.Errorf("#%d(%s): EncodeElement(%#v, %#v):\nhave %#q\nwant %#q", idx, test.desc, test.value, test.start, got, want) + } + } +} + +func BenchmarkMarshal(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + Marshal(atomValue) + } +} + +func BenchmarkUnmarshal(b *testing.B) { + b.ReportAllocs() + xml := []byte(atomXml) + for i := 0; i < b.N; i++ { + Unmarshal(xml, &Feed{}) + } +} + +// golang.org/issue/6556 +func TestStructPointerMarshal(t *testing.T) { + type A struct { + XMLName string `xml:"a"` + B []interface{} + } + type C struct { + XMLName Name + Value string `xml:"value"` + } + + a := new(A) + a.B = append(a.B, &C{ + XMLName: Name{Local: "c"}, + Value: "x", + }) + + b, err := Marshal(a) + if err != nil { + t.Fatal(err) + } + if x := string(b); x != "x" { + t.Fatal(x) + } + var v A + err = Unmarshal(b, &v) + if err != nil { + t.Fatal(err) + } +} + +var encodeTokenTests = []struct { + desc string + toks []Token + want string + err string +}{{ + desc: "start element with name space", + toks: []Token{ + StartElement{Name{"space", "local"}, nil}, + }, + want: ``, +}, { + desc: "start element with no name", + toks: []Token{ + StartElement{Name{"space", ""}, 
nil}, + }, + err: "xml: start tag with no name", +}, { + desc: "end element with no name", + toks: []Token{ + EndElement{Name{"space", ""}}, + }, + err: "xml: end tag with no name", +}, { + desc: "char data", + toks: []Token{ + CharData("foo"), + }, + want: `foo`, +}, { + desc: "char data with escaped chars", + toks: []Token{ + CharData(" \t\n"), + }, + want: " \n", +}, { + desc: "comment", + toks: []Token{ + Comment("foo"), + }, + want: ``, +}, { + desc: "comment with invalid content", + toks: []Token{ + Comment("foo-->"), + }, + err: "xml: EncodeToken of Comment containing --> marker", +}, { + desc: "proc instruction", + toks: []Token{ + ProcInst{"Target", []byte("Instruction")}, + }, + want: ``, +}, { + desc: "proc instruction with empty target", + toks: []Token{ + ProcInst{"", []byte("Instruction")}, + }, + err: "xml: EncodeToken of ProcInst with invalid Target", +}, { + desc: "proc instruction with bad content", + toks: []Token{ + ProcInst{"", []byte("Instruction?>")}, + }, + err: "xml: EncodeToken of ProcInst with invalid Target", +}, { + desc: "directive", + toks: []Token{ + Directive("foo"), + }, + want: ``, +}, { + desc: "more complex directive", + toks: []Token{ + Directive("DOCTYPE doc [ '> ]"), + }, + want: `'> ]>`, +}, { + desc: "directive instruction with bad name", + toks: []Token{ + Directive("foo>"), + }, + err: "xml: EncodeToken of Directive containing wrong < or > markers", +}, { + desc: "end tag without start tag", + toks: []Token{ + EndElement{Name{"foo", "bar"}}, + }, + err: "xml: end tag without start tag", +}, { + desc: "mismatching end tag local name", + toks: []Token{ + StartElement{Name{"", "foo"}, nil}, + EndElement{Name{"", "bar"}}, + }, + err: "xml: end tag does not match start tag ", + want: ``, +}, { + desc: "mismatching end tag namespace", + toks: []Token{ + StartElement{Name{"space", "foo"}, nil}, + EndElement{Name{"another", "foo"}}, + }, + err: "xml: end tag in namespace another does not match start tag in namespace space", + want: ``, +}, { + desc: "start element with explicit namespace", + toks: []Token{ + StartElement{Name{"space", "local"}, []Attr{ + {Name{"xmlns", "x"}, "space"}, + {Name{"space", "foo"}, "value"}, + }}, + }, + want: ``, +}, { + desc: "start element with explicit namespace and colliding prefix", + toks: []Token{ + StartElement{Name{"space", "local"}, []Attr{ + {Name{"xmlns", "x"}, "space"}, + {Name{"space", "foo"}, "value"}, + {Name{"x", "bar"}, "other"}, + }}, + }, + want: ``, +}, { + desc: "start element using previously defined namespace", + toks: []Token{ + StartElement{Name{"", "local"}, []Attr{ + {Name{"xmlns", "x"}, "space"}, + }}, + StartElement{Name{"space", "foo"}, []Attr{ + {Name{"space", "x"}, "y"}, + }}, + }, + want: ``, +}, { + desc: "nested name space with same prefix", + toks: []Token{ + StartElement{Name{"", "foo"}, []Attr{ + {Name{"xmlns", "x"}, "space1"}, + }}, + StartElement{Name{"", "foo"}, []Attr{ + {Name{"xmlns", "x"}, "space2"}, + }}, + StartElement{Name{"", "foo"}, []Attr{ + {Name{"space1", "a"}, "space1 value"}, + {Name{"space2", "b"}, "space2 value"}, + }}, + EndElement{Name{"", "foo"}}, + EndElement{Name{"", "foo"}}, + StartElement{Name{"", "foo"}, []Attr{ + {Name{"space1", "a"}, "space1 value"}, + {Name{"space2", "b"}, "space2 value"}, + }}, + }, + want: ``, +}, { + desc: "start element defining several prefixes for the same name space", + toks: []Token{ + StartElement{Name{"space", "foo"}, []Attr{ + {Name{"xmlns", "a"}, "space"}, + {Name{"xmlns", "b"}, "space"}, + {Name{"space", "x"}, "value"}, + }}, + 
}, + want: ``, +}, { + desc: "nested element redefines name space", + toks: []Token{ + StartElement{Name{"", "foo"}, []Attr{ + {Name{"xmlns", "x"}, "space"}, + }}, + StartElement{Name{"space", "foo"}, []Attr{ + {Name{"xmlns", "y"}, "space"}, + {Name{"space", "a"}, "value"}, + }}, + }, + want: ``, +}, { + desc: "nested element creates alias for default name space", + toks: []Token{ + StartElement{Name{"space", "foo"}, []Attr{ + {Name{"", "xmlns"}, "space"}, + }}, + StartElement{Name{"space", "foo"}, []Attr{ + {Name{"xmlns", "y"}, "space"}, + {Name{"space", "a"}, "value"}, + }}, + }, + want: ``, +}, { + desc: "nested element defines default name space with existing prefix", + toks: []Token{ + StartElement{Name{"", "foo"}, []Attr{ + {Name{"xmlns", "x"}, "space"}, + }}, + StartElement{Name{"space", "foo"}, []Attr{ + {Name{"", "xmlns"}, "space"}, + {Name{"space", "a"}, "value"}, + }}, + }, + want: ``, +}, { + desc: "nested element uses empty attribute name space when default ns defined", + toks: []Token{ + StartElement{Name{"space", "foo"}, []Attr{ + {Name{"", "xmlns"}, "space"}, + }}, + StartElement{Name{"space", "foo"}, []Attr{ + {Name{"", "attr"}, "value"}, + }}, + }, + want: ``, +}, { + desc: "redefine xmlns", + toks: []Token{ + StartElement{Name{"", "foo"}, []Attr{ + {Name{"foo", "xmlns"}, "space"}, + }}, + }, + err: `xml: cannot redefine xmlns attribute prefix`, +}, { + desc: "xmlns with explicit name space #1", + toks: []Token{ + StartElement{Name{"space", "foo"}, []Attr{ + {Name{"xml", "xmlns"}, "space"}, + }}, + }, + want: ``, +}, { + desc: "xmlns with explicit name space #2", + toks: []Token{ + StartElement{Name{"space", "foo"}, []Attr{ + {Name{xmlURL, "xmlns"}, "space"}, + }}, + }, + want: ``, +}, { + desc: "empty name space declaration is ignored", + toks: []Token{ + StartElement{Name{"", "foo"}, []Attr{ + {Name{"xmlns", "foo"}, ""}, + }}, + }, + want: ``, +}, { + desc: "attribute with no name is ignored", + toks: []Token{ + StartElement{Name{"", "foo"}, []Attr{ + {Name{"", ""}, "value"}, + }}, + }, + want: ``, +}, { + desc: "namespace URL with non-valid name", + toks: []Token{ + StartElement{Name{"/34", "foo"}, []Attr{ + {Name{"/34", "x"}, "value"}, + }}, + }, + want: `<_:foo xmlns:_="/34" _:x="value">`, +}, { + desc: "nested element resets default namespace to empty", + toks: []Token{ + StartElement{Name{"space", "foo"}, []Attr{ + {Name{"", "xmlns"}, "space"}, + }}, + StartElement{Name{"", "foo"}, []Attr{ + {Name{"", "xmlns"}, ""}, + {Name{"", "x"}, "value"}, + {Name{"space", "x"}, "value"}, + }}, + }, + want: ``, +}, { + desc: "nested element requires empty default name space", + toks: []Token{ + StartElement{Name{"space", "foo"}, []Attr{ + {Name{"", "xmlns"}, "space"}, + }}, + StartElement{Name{"", "foo"}, nil}, + }, + want: ``, +}, { + desc: "attribute uses name space from xmlns", + toks: []Token{ + StartElement{Name{"some/space", "foo"}, []Attr{ + {Name{"", "attr"}, "value"}, + {Name{"some/space", "other"}, "other value"}, + }}, + }, + want: ``, +}, { + desc: "default name space should not be used by attributes", + toks: []Token{ + StartElement{Name{"space", "foo"}, []Attr{ + {Name{"", "xmlns"}, "space"}, + {Name{"xmlns", "bar"}, "space"}, + {Name{"space", "baz"}, "foo"}, + }}, + StartElement{Name{"space", "baz"}, nil}, + EndElement{Name{"space", "baz"}}, + EndElement{Name{"space", "foo"}}, + }, + want: ``, +}, { + desc: "default name space not used by attributes, not explicitly defined", + toks: []Token{ + StartElement{Name{"space", "foo"}, []Attr{ + {Name{"", "xmlns"}, 
"space"}, + {Name{"space", "baz"}, "foo"}, + }}, + StartElement{Name{"space", "baz"}, nil}, + EndElement{Name{"space", "baz"}}, + EndElement{Name{"space", "foo"}}, + }, + want: ``, +}, { + desc: "impossible xmlns declaration", + toks: []Token{ + StartElement{Name{"", "foo"}, []Attr{ + {Name{"", "xmlns"}, "space"}, + }}, + StartElement{Name{"space", "bar"}, []Attr{ + {Name{"space", "attr"}, "value"}, + }}, + }, + want: ``, +}} + +func TestEncodeToken(t *testing.T) { +loop: + for i, tt := range encodeTokenTests { + var buf bytes.Buffer + enc := NewEncoder(&buf) + var err error + for j, tok := range tt.toks { + err = enc.EncodeToken(tok) + if err != nil && j < len(tt.toks)-1 { + t.Errorf("#%d %s token #%d: %v", i, tt.desc, j, err) + continue loop + } + } + errorf := func(f string, a ...interface{}) { + t.Errorf("#%d %s token #%d:%s", i, tt.desc, len(tt.toks)-1, fmt.Sprintf(f, a...)) + } + switch { + case tt.err != "" && err == nil: + errorf(" expected error; got none") + continue + case tt.err == "" && err != nil: + errorf(" got error: %v", err) + continue + case tt.err != "" && err != nil && tt.err != err.Error(): + errorf(" error mismatch; got %v, want %v", err, tt.err) + continue + } + if err := enc.Flush(); err != nil { + errorf(" %v", err) + continue + } + if got := buf.String(); got != tt.want { + errorf("\ngot %v\nwant %v", got, tt.want) + continue + } + } +} + +func TestProcInstEncodeToken(t *testing.T) { + var buf bytes.Buffer + enc := NewEncoder(&buf) + + if err := enc.EncodeToken(ProcInst{"xml", []byte("Instruction")}); err != nil { + t.Fatalf("enc.EncodeToken: expected to be able to encode xml target ProcInst as first token, %s", err) + } + + if err := enc.EncodeToken(ProcInst{"Target", []byte("Instruction")}); err != nil { + t.Fatalf("enc.EncodeToken: expected to be able to add non-xml target ProcInst") + } + + if err := enc.EncodeToken(ProcInst{"xml", []byte("Instruction")}); err == nil { + t.Fatalf("enc.EncodeToken: expected to not be allowed to encode xml target ProcInst when not first token") + } +} + +func TestDecodeEncode(t *testing.T) { + var in, out bytes.Buffer + in.WriteString(` + + + +`) + dec := NewDecoder(&in) + enc := NewEncoder(&out) + for tok, err := dec.Token(); err == nil; tok, err = dec.Token() { + err = enc.EncodeToken(tok) + if err != nil { + t.Fatalf("enc.EncodeToken: Unable to encode token (%#v), %v", tok, err) + } + } +} + +// Issue 9796. Used to fail with GORACE="halt_on_error=1" -race. +func TestRace9796(t *testing.T) { + type A struct{} + type B struct { + C []A `xml:"X>Y"` + } + var wg sync.WaitGroup + for i := 0; i < 2; i++ { + wg.Add(1) + go func() { + Marshal(B{[]A{{}}}) + wg.Done() + }() + } + wg.Wait() +} + +func TestIsValidDirective(t *testing.T) { + testOK := []string{ + "<>", + "< < > >", + "' '>' >", + " ]>", + " '<' ' doc ANY> ]>", + ">>> a < comment --> [ ] >", + } + testKO := []string{ + "<", + ">", + "", + "< > > < < >", + " -->", + "", + "'", + "", + } + for _, s := range testOK { + if !isValidDirective(Directive(s)) { + t.Errorf("Directive %q is expected to be valid", s) + } + } + for _, s := range testKO { + if isValidDirective(Directive(s)) { + t.Errorf("Directive %q is expected to be invalid", s) + } + } +} + +// Issue 11719. EncodeToken used to silently eat tokens with an invalid type. 
+func TestSimpleUseOfEncodeToken(t *testing.T) {
+	var buf bytes.Buffer
+	enc := NewEncoder(&buf)
+	if err := enc.EncodeToken(&StartElement{Name: Name{"", "object1"}}); err == nil {
+		t.Errorf("enc.EncodeToken: pointer type should be rejected")
+	}
+	if err := enc.EncodeToken(&EndElement{Name: Name{"", "object1"}}); err == nil {
+		t.Errorf("enc.EncodeToken: pointer type should be rejected")
+	}
+	if err := enc.EncodeToken(StartElement{Name: Name{"", "object2"}}); err != nil {
+		t.Errorf("enc.EncodeToken: StartElement %s", err)
+	}
+	if err := enc.EncodeToken(EndElement{Name: Name{"", "object2"}}); err != nil {
+		t.Errorf("enc.EncodeToken: EndElement %s", err)
+	}
+	if err := enc.EncodeToken(Universe{}); err == nil {
+		t.Errorf("enc.EncodeToken: invalid type not caught")
+	}
+	if err := enc.Flush(); err != nil {
+		t.Errorf("enc.Flush: %s", err)
+	}
+	if buf.Len() == 0 {
+		t.Errorf("enc.EncodeToken: empty buffer")
+	}
+	want := "<object2></object2>"
+	if buf.String() != want {
+		t.Errorf("enc.EncodeToken: expected %q; got %q", want, buf.String())
+	}
+}
diff --git a/server/webdav/internal/xml/read.go b/server/webdav/internal/xml/read.go
new file mode 100644
index 0000000000000000000000000000000000000000..bfaef6f17f1ad865c70041ef8c75d542e5235447
--- /dev/null
+++ b/server/webdav/internal/xml/read.go
@@ -0,0 +1,691 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package xml
+
+import (
+	"bytes"
+	"encoding"
+	"errors"
+	"fmt"
+	"reflect"
+	"strconv"
+	"strings"
+)
+
+// BUG(rsc): Mapping between XML elements and data structures is inherently flawed:
+// an XML element is an order-dependent collection of anonymous
+// values, while a data structure is an order-independent collection
+// of named values.
+// See package json for a textual representation more suitable
+// to data structures.
+
+// Unmarshal parses the XML-encoded data and stores the result in
+// the value pointed to by v, which must be an arbitrary struct,
+// slice, or string. Well-formed data that does not fit into v is
+// discarded.
+//
+// Because Unmarshal uses the reflect package, it can only assign
+// to exported (upper case) fields. Unmarshal uses a case-sensitive
+// comparison to match XML element names to tag values and struct
+// field names.
+//
+// Unmarshal maps an XML element to a struct using the following rules.
+// In the rules, the tag of a field refers to the value associated with the
+// key 'xml' in the struct field's tag (see the example above).
+//
+// - If the struct has a field of type []byte or string with tag
+// ",innerxml", Unmarshal accumulates the raw XML nested inside the
+// element in that field. The rest of the rules still apply.
+//
+// - If the struct has a field named XMLName of type xml.Name,
+// Unmarshal records the element name in that field.
+//
+// - If the XMLName field has an associated tag of the form
+// "name" or "namespace-URL name", the XML element must have
+// the given name (and, optionally, name space) or else Unmarshal
+// returns an error.
+//
+// - If the XML element has an attribute whose name matches a
+// struct field name with an associated tag containing ",attr" or
+// the explicit name in a struct field tag of the form "name,attr",
+// Unmarshal records the attribute value in that field.
+//
+// - If the XML element contains character data, that data is
+// accumulated in the first struct field that has tag ",chardata".
+// The struct field may have type []byte or string. +// If there is no such field, the character data is discarded. +// +// - If the XML element contains comments, they are accumulated in +// the first struct field that has tag ",comment". The struct +// field may have type []byte or string. If there is no such +// field, the comments are discarded. +// +// - If the XML element contains a sub-element whose name matches +// the prefix of a tag formatted as "a" or "a>b>c", unmarshal +// will descend into the XML structure looking for elements with the +// given names, and will map the innermost elements to that struct +// field. A tag starting with ">" is equivalent to one starting +// with the field name followed by ">". +// +// - If the XML element contains a sub-element whose name matches +// a struct field's XMLName tag and the struct field has no +// explicit name tag as per the previous rule, unmarshal maps +// the sub-element to that struct field. +// +// - If the XML element contains a sub-element whose name matches a +// field without any mode flags (",attr", ",chardata", etc), Unmarshal +// maps the sub-element to that struct field. +// +// - If the XML element contains a sub-element that hasn't matched any +// of the above rules and the struct has a field with tag ",any", +// unmarshal maps the sub-element to that struct field. +// +// - An anonymous struct field is handled as if the fields of its +// value were part of the outer struct. +// +// - A struct field with tag "-" is never unmarshalled into. +// +// Unmarshal maps an XML element to a string or []byte by saving the +// concatenation of that element's character data in the string or +// []byte. The saved []byte is never nil. +// +// Unmarshal maps an attribute value to a string or []byte by saving +// the value in the string or slice. +// +// Unmarshal maps an XML element to a slice by extending the length of +// the slice and mapping the element to the newly created value. +// +// Unmarshal maps an XML element or attribute value to a bool by +// setting it to the boolean value represented by the string. +// +// Unmarshal maps an XML element or attribute value to an integer or +// floating-point field by setting the field to the result of +// interpreting the string value in decimal. There is no check for +// overflow. +// +// Unmarshal maps an XML element to an xml.Name by recording the +// element name. +// +// Unmarshal maps an XML element to a pointer by setting the pointer +// to a freshly allocated value and then mapping the element to that value. +func Unmarshal(data []byte, v interface{}) error { + return NewDecoder(bytes.NewReader(data)).Decode(v) +} + +// Decode works like xml.Unmarshal, except it reads the decoder +// stream to find the start element. +func (d *Decoder) Decode(v interface{}) error { + return d.DecodeElement(v, nil) +} + +// DecodeElement works like xml.Unmarshal except that it takes +// a pointer to the start XML element to decode into v. +// It is useful when a client reads some raw XML tokens itself +// but also wants to defer to Unmarshal for some elements. +func (d *Decoder) DecodeElement(v interface{}, start *StartElement) error { + val := reflect.ValueOf(v) + if val.Kind() != reflect.Ptr { + return errors.New("non-pointer passed to Unmarshal") + } + return d.unmarshal(val.Elem(), start) +} + +// An UnmarshalError represents an error in the unmarshalling process. 
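+// It is returned, for example, when an XMLName tag requires element
+// <foo> but the document supplies <bar>.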
+type UnmarshalError string + +func (e UnmarshalError) Error() string { return string(e) } + +// Unmarshaler is the interface implemented by objects that can unmarshal +// an XML element description of themselves. +// +// UnmarshalXML decodes a single XML element +// beginning with the given start element. +// If it returns an error, the outer call to Unmarshal stops and +// returns that error. +// UnmarshalXML must consume exactly one XML element. +// One common implementation strategy is to unmarshal into +// a separate value with a layout matching the expected XML +// using d.DecodeElement, and then to copy the data from +// that value into the receiver. +// Another common strategy is to use d.Token to process the +// XML object one token at a time. +// UnmarshalXML may not use d.RawToken. +type Unmarshaler interface { + UnmarshalXML(d *Decoder, start StartElement) error +} + +// UnmarshalerAttr is the interface implemented by objects that can unmarshal +// an XML attribute description of themselves. +// +// UnmarshalXMLAttr decodes a single XML attribute. +// If it returns an error, the outer call to Unmarshal stops and +// returns that error. +// UnmarshalXMLAttr is used only for struct fields with the +// "attr" option in the field tag. +type UnmarshalerAttr interface { + UnmarshalXMLAttr(attr Attr) error +} + +// receiverType returns the receiver type to use in an expression like "%s.MethodName". +func receiverType(val interface{}) string { + t := reflect.TypeOf(val) + if t.Name() != "" { + return t.String() + } + return "(" + t.String() + ")" +} + +// unmarshalInterface unmarshals a single XML element into val. +// start is the opening tag of the element. +func (p *Decoder) unmarshalInterface(val Unmarshaler, start *StartElement) error { + // Record that decoder must stop at end tag corresponding to start. + p.pushEOF() + + p.unmarshalDepth++ + err := val.UnmarshalXML(p, *start) + p.unmarshalDepth-- + if err != nil { + p.popEOF() + return err + } + + if !p.popEOF() { + return fmt.Errorf("xml: %s.UnmarshalXML did not consume entire <%s> element", receiverType(val), start.Name.Local) + } + + return nil +} + +// unmarshalTextInterface unmarshals a single XML element into val. +// The chardata contained in the element (but not its children) +// is passed to the text unmarshaler. +func (p *Decoder) unmarshalTextInterface(val encoding.TextUnmarshaler, start *StartElement) error { + var buf []byte + depth := 1 + for depth > 0 { + t, err := p.Token() + if err != nil { + return err + } + switch t := t.(type) { + case CharData: + if depth == 1 { + buf = append(buf, t...) + } + case StartElement: + depth++ + case EndElement: + depth-- + } + } + return val.UnmarshalText(buf) +} + +// unmarshalAttr unmarshals a single XML attribute into val. +func (p *Decoder) unmarshalAttr(val reflect.Value, attr Attr) error { + if val.Kind() == reflect.Ptr { + if val.IsNil() { + val.Set(reflect.New(val.Type().Elem())) + } + val = val.Elem() + } + + if val.CanInterface() && val.Type().Implements(unmarshalerAttrType) { + // This is an unmarshaler with a non-pointer receiver, + // so it's likely to be incorrect, but we do what we're told. + return val.Interface().(UnmarshalerAttr).UnmarshalXMLAttr(attr) + } + if val.CanAddr() { + pv := val.Addr() + if pv.CanInterface() && pv.Type().Implements(unmarshalerAttrType) { + return pv.Interface().(UnmarshalerAttr).UnmarshalXMLAttr(attr) + } + } + + // Not an UnmarshalerAttr; try encoding.TextUnmarshaler. 
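+	// The fallback order mirrors unmarshal below: prefer a value-receiver
+	// implementation, then retry through the addressable pointer, and only
+	// then copy the raw attribute value into the field via copyValue.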
+ if val.CanInterface() && val.Type().Implements(textUnmarshalerType) { + // This is an unmarshaler with a non-pointer receiver, + // so it's likely to be incorrect, but we do what we're told. + return val.Interface().(encoding.TextUnmarshaler).UnmarshalText([]byte(attr.Value)) + } + if val.CanAddr() { + pv := val.Addr() + if pv.CanInterface() && pv.Type().Implements(textUnmarshalerType) { + return pv.Interface().(encoding.TextUnmarshaler).UnmarshalText([]byte(attr.Value)) + } + } + + copyValue(val, []byte(attr.Value)) + return nil +} + +var ( + unmarshalerType = reflect.TypeOf((*Unmarshaler)(nil)).Elem() + unmarshalerAttrType = reflect.TypeOf((*UnmarshalerAttr)(nil)).Elem() + textUnmarshalerType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem() +) + +// Unmarshal a single XML element into val. +func (p *Decoder) unmarshal(val reflect.Value, start *StartElement) error { + // Find start element if we need it. + if start == nil { + for { + tok, err := p.Token() + if err != nil { + return err + } + if t, ok := tok.(StartElement); ok { + start = &t + break + } + } + } + + // Load value from interface, but only if the result will be + // usefully addressable. + if val.Kind() == reflect.Interface && !val.IsNil() { + e := val.Elem() + if e.Kind() == reflect.Ptr && !e.IsNil() { + val = e + } + } + + if val.Kind() == reflect.Ptr { + if val.IsNil() { + val.Set(reflect.New(val.Type().Elem())) + } + val = val.Elem() + } + + if val.CanInterface() && val.Type().Implements(unmarshalerType) { + // This is an unmarshaler with a non-pointer receiver, + // so it's likely to be incorrect, but we do what we're told. + return p.unmarshalInterface(val.Interface().(Unmarshaler), start) + } + + if val.CanAddr() { + pv := val.Addr() + if pv.CanInterface() && pv.Type().Implements(unmarshalerType) { + return p.unmarshalInterface(pv.Interface().(Unmarshaler), start) + } + } + + if val.CanInterface() && val.Type().Implements(textUnmarshalerType) { + return p.unmarshalTextInterface(val.Interface().(encoding.TextUnmarshaler), start) + } + + if val.CanAddr() { + pv := val.Addr() + if pv.CanInterface() && pv.Type().Implements(textUnmarshalerType) { + return p.unmarshalTextInterface(pv.Interface().(encoding.TextUnmarshaler), start) + } + } + + var ( + data []byte + saveData reflect.Value + comment []byte + saveComment reflect.Value + saveXML reflect.Value + saveXMLIndex int + saveXMLData []byte + saveAny reflect.Value + sv reflect.Value + tinfo *typeInfo + err error + ) + + switch v := val; v.Kind() { + default: + return errors.New("unknown type " + v.Type().String()) + + case reflect.Interface: + // TODO: For now, simply ignore the field. In the near + // future we may choose to unmarshal the start + // element on it, if not nil. + return p.Skip() + + case reflect.Slice: + typ := v.Type() + if typ.Elem().Kind() == reflect.Uint8 { + // []byte + saveData = v + break + } + + // Slice of element values. + // Grow slice. + n := v.Len() + if n >= v.Cap() { + ncap := 2 * n + if ncap < 4 { + ncap = 4 + } + new := reflect.MakeSlice(typ, n, ncap) + reflect.Copy(new, v) + v.Set(new) + } + v.SetLen(n + 1) + + // Recur to read element into slice. 
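+		// If decoding the new element fails, shrink the slice back to
+		// its original length so the partially decoded value is dropped.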
+ if err := p.unmarshal(v.Index(n), start); err != nil { + v.SetLen(n) + return err + } + return nil + + case reflect.Bool, reflect.Float32, reflect.Float64, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr, reflect.String: + saveData = v + + case reflect.Struct: + typ := v.Type() + if typ == nameType { + v.Set(reflect.ValueOf(start.Name)) + break + } + + sv = v + tinfo, err = getTypeInfo(typ) + if err != nil { + return err + } + + // Validate and assign element name. + if tinfo.xmlname != nil { + finfo := tinfo.xmlname + if finfo.name != "" && finfo.name != start.Name.Local { + return UnmarshalError("expected element type <" + finfo.name + "> but have <" + start.Name.Local + ">") + } + if finfo.xmlns != "" && finfo.xmlns != start.Name.Space { + e := "expected element <" + finfo.name + "> in name space " + finfo.xmlns + " but have " + if start.Name.Space == "" { + e += "no name space" + } else { + e += start.Name.Space + } + return UnmarshalError(e) + } + fv := finfo.value(sv) + if _, ok := fv.Interface().(Name); ok { + fv.Set(reflect.ValueOf(start.Name)) + } + } + + // Assign attributes. + // Also, determine whether we need to save character data or comments. + for i := range tinfo.fields { + finfo := &tinfo.fields[i] + switch finfo.flags & fMode { + case fAttr: + strv := finfo.value(sv) + // Look for attribute. + for _, a := range start.Attr { + if a.Name.Local == finfo.name && (finfo.xmlns == "" || finfo.xmlns == a.Name.Space) { + if err := p.unmarshalAttr(strv, a); err != nil { + return err + } + break + } + } + + case fCharData: + if !saveData.IsValid() { + saveData = finfo.value(sv) + } + + case fComment: + if !saveComment.IsValid() { + saveComment = finfo.value(sv) + } + + case fAny, fAny | fElement: + if !saveAny.IsValid() { + saveAny = finfo.value(sv) + } + + case fInnerXml: + if !saveXML.IsValid() { + saveXML = finfo.value(sv) + if p.saved == nil { + saveXMLIndex = 0 + p.saved = new(bytes.Buffer) + } else { + saveXMLIndex = p.savedOffset() + } + } + } + } + } + + // Find end element. + // Process sub-elements along the way. +Loop: + for { + var savedOffset int + if saveXML.IsValid() { + savedOffset = p.savedOffset() + } + tok, err := p.Token() + if err != nil { + return err + } + switch t := tok.(type) { + case StartElement: + consumed := false + if sv.IsValid() { + consumed, err = p.unmarshalPath(tinfo, sv, nil, &t) + if err != nil { + return err + } + if !consumed && saveAny.IsValid() { + consumed = true + if err := p.unmarshal(saveAny, &t); err != nil { + return err + } + } + } + if !consumed { + if err := p.Skip(); err != nil { + return err + } + } + + case EndElement: + if saveXML.IsValid() { + saveXMLData = p.saved.Bytes()[saveXMLIndex:savedOffset] + if saveXMLIndex == 0 { + p.saved = nil + } + } + break Loop + + case CharData: + if saveData.IsValid() { + data = append(data, t...) + } + + case Comment: + if saveComment.IsValid() { + comment = append(comment, t...) 
+ } + } + } + + if saveData.IsValid() && saveData.CanInterface() && saveData.Type().Implements(textUnmarshalerType) { + if err := saveData.Interface().(encoding.TextUnmarshaler).UnmarshalText(data); err != nil { + return err + } + saveData = reflect.Value{} + } + + if saveData.IsValid() && saveData.CanAddr() { + pv := saveData.Addr() + if pv.CanInterface() && pv.Type().Implements(textUnmarshalerType) { + if err := pv.Interface().(encoding.TextUnmarshaler).UnmarshalText(data); err != nil { + return err + } + saveData = reflect.Value{} + } + } + + if err := copyValue(saveData, data); err != nil { + return err + } + + switch t := saveComment; t.Kind() { + case reflect.String: + t.SetString(string(comment)) + case reflect.Slice: + t.Set(reflect.ValueOf(comment)) + } + + switch t := saveXML; t.Kind() { + case reflect.String: + t.SetString(string(saveXMLData)) + case reflect.Slice: + t.Set(reflect.ValueOf(saveXMLData)) + } + + return nil +} + +func copyValue(dst reflect.Value, src []byte) (err error) { + dst0 := dst + + if dst.Kind() == reflect.Ptr { + if dst.IsNil() { + dst.Set(reflect.New(dst.Type().Elem())) + } + dst = dst.Elem() + } + + // Save accumulated data. + switch dst.Kind() { + case reflect.Invalid: + // Probably a comment. + default: + return errors.New("cannot unmarshal into " + dst0.Type().String()) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + itmp, err := strconv.ParseInt(string(src), 10, dst.Type().Bits()) + if err != nil { + return err + } + dst.SetInt(itmp) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + utmp, err := strconv.ParseUint(string(src), 10, dst.Type().Bits()) + if err != nil { + return err + } + dst.SetUint(utmp) + case reflect.Float32, reflect.Float64: + ftmp, err := strconv.ParseFloat(string(src), dst.Type().Bits()) + if err != nil { + return err + } + dst.SetFloat(ftmp) + case reflect.Bool: + value, err := strconv.ParseBool(strings.TrimSpace(string(src))) + if err != nil { + return err + } + dst.SetBool(value) + case reflect.String: + dst.SetString(string(src)) + case reflect.Slice: + if len(src) == 0 { + // non-nil to flag presence + src = []byte{} + } + dst.SetBytes(src) + } + return nil +} + +// unmarshalPath walks down an XML structure looking for wanted +// paths, and calls unmarshal on them. +// The consumed result tells whether XML elements have been consumed +// from the Decoder until start's matching end element, or if it's +// still untouched because start is uninteresting for sv's fields. +func (p *Decoder) unmarshalPath(tinfo *typeInfo, sv reflect.Value, parents []string, start *StartElement) (consumed bool, err error) { + recurse := false +Loop: + for i := range tinfo.fields { + finfo := &tinfo.fields[i] + if finfo.flags&fElement == 0 || len(finfo.parents) < len(parents) || finfo.xmlns != "" && finfo.xmlns != start.Name.Space { + continue + } + for j := range parents { + if parents[j] != finfo.parents[j] { + continue Loop + } + } + if len(finfo.parents) == len(parents) && finfo.name == start.Name.Local { + // It's a perfect match, unmarshal the field. + return true, p.unmarshal(finfo.value(sv), start) + } + if len(finfo.parents) > len(parents) && finfo.parents[len(parents)] == start.Name.Local { + // It's a prefix for the field. Break and recurse + // since it's not ok for one field path to be itself + // the prefix for another field path. + recurse = true + + // We can reuse the same slice as long as we + // don't try to append to it. 
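+			// Re-slicing extends the parent path window in place, so the
+			// recursion below sees the longer path without a new allocation.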
+ parents = finfo.parents[:len(parents)+1] + break + } + } + if !recurse { + // We have no business with this element. + return false, nil + } + // The element is not a perfect match for any field, but one + // or more fields have the path to this element as a parent + // prefix. Recurse and attempt to match these. + for { + var tok Token + tok, err = p.Token() + if err != nil { + return true, err + } + switch t := tok.(type) { + case StartElement: + consumed2, err := p.unmarshalPath(tinfo, sv, parents, &t) + if err != nil { + return true, err + } + if !consumed2 { + if err := p.Skip(); err != nil { + return true, err + } + } + case EndElement: + return true, nil + } + } +} + +// Skip reads tokens until it has consumed the end element +// matching the most recent start element already consumed. +// It recurs if it encounters a start element, so it can be used to +// skip nested structures. +// It returns nil if it finds an end element matching the start +// element; otherwise it returns an error describing the problem. +func (d *Decoder) Skip() error { + for { + tok, err := d.Token() + if err != nil { + return err + } + switch tok.(type) { + case StartElement: + if err := d.Skip(); err != nil { + return err + } + case EndElement: + return nil + } + } +} diff --git a/server/webdav/internal/xml/read_test.go b/server/webdav/internal/xml/read_test.go new file mode 100644 index 0000000000000000000000000000000000000000..02f1e10c330add9ab1c30cfc648c0241094e3255 --- /dev/null +++ b/server/webdav/internal/xml/read_test.go @@ -0,0 +1,744 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xml + +import ( + "bytes" + "fmt" + "io" + "reflect" + "strings" + "testing" + "time" +) + +// Stripped down Atom feed data structures. + +func TestUnmarshalFeed(t *testing.T) { + var f Feed + if err := Unmarshal([]byte(atomFeedString), &f); err != nil { + t.Fatalf("Unmarshal: %s", err) + } + if !reflect.DeepEqual(f, atomFeed) { + t.Fatalf("have %#v\nwant %#v", f, atomFeed) + } +} + +// hget http://codereview.appspot.com/rss/mine/rsc +const atomFeedString = ` + +Code Review - My issueshttp://codereview.appspot.com/rietveld<>rietveld: an attempt at pubsubhubbub +2009-10-04T01:35:58+00:00email-address-removedurn:md5:134d9179c41f806be79b3a5f7877d19a + An attempt at adding pubsubhubbub support to Rietveld. +http://code.google.com/p/pubsubhubbub +http://code.google.com/p/rietveld/issues/detail?id=155 + +The server side of the protocol is trivial: + 1. add a &lt;link rel=&quot;hub&quot; href=&quot;hub-server&quot;&gt; tag to all + feeds that will be pubsubhubbubbed. + 2. every time one of those feeds changes, tell the hub + with a simple POST request. + +I have tested this by adding debug prints to a local hub +server and checking that the server got the right publish +requests. + +I can&#39;t quite get the server to work, but I think the bug +is not in my code. I think that the server expects to be +able to grab the feed and see the feed&#39;s actual URL in +the link rel=&quot;self&quot;, but the default value for that drops +the :port from the URL, and I cannot for the life of me +figure out how to get the Atom generator deep inside +django not to do that, or even where it is doing that, +or even what code is running to generate the Atom feed. +(I thought I knew but I added some assert False statements +and it kept running!) 
+ +Ignoring that particular problem, I would appreciate +feedback on the right way to get the two values at +the top of feeds.py marked NOTE(rsc). + + +rietveld: correct tab handling +2009-10-03T23:02:17+00:00email-address-removedurn:md5:0a2a4f19bb815101f0ba2904aed7c35a + This fixes the buggy tab rendering that can be seen at +http://codereview.appspot.com/116075/diff/1/2 + +The fundamental problem was that the tab code was +not being told what column the text began in, so it +didn&#39;t know where to put the tab stops. Another problem +was that some of the code assumed that string byte +offsets were the same as column offsets, which is only +true if there are no tabs. + +In the process of fixing this, I cleaned up the arguments +to Fold and ExpandTabs and renamed them Break and +_ExpandTabs so that I could be sure that I found all the +call sites. I also wanted to verify that ExpandTabs was +not being used from outside intra_region_diff.py. + + + ` + +type Feed struct { + XMLName Name `xml:"http://www.w3.org/2005/Atom feed"` + Title string `xml:"title"` + Id string `xml:"id"` + Link []Link `xml:"link"` + Updated time.Time `xml:"updated,attr"` + Author Person `xml:"author"` + Entry []Entry `xml:"entry"` +} + +type Entry struct { + Title string `xml:"title"` + Id string `xml:"id"` + Link []Link `xml:"link"` + Updated time.Time `xml:"updated"` + Author Person `xml:"author"` + Summary Text `xml:"summary"` +} + +type Link struct { + Rel string `xml:"rel,attr,omitempty"` + Href string `xml:"href,attr"` +} + +type Person struct { + Name string `xml:"name"` + URI string `xml:"uri"` + Email string `xml:"email"` + InnerXML string `xml:",innerxml"` +} + +type Text struct { + Type string `xml:"type,attr,omitempty"` + Body string `xml:",chardata"` +} + +var atomFeed = Feed{ + XMLName: Name{"http://www.w3.org/2005/Atom", "feed"}, + Title: "Code Review - My issues", + Link: []Link{ + {Rel: "alternate", Href: "http://codereview.appspot.com/"}, + {Rel: "self", Href: "http://codereview.appspot.com/rss/mine/rsc"}, + }, + Id: "http://codereview.appspot.com/", + Updated: ParseTime("2009-10-04T01:35:58+00:00"), + Author: Person{ + Name: "rietveld<>", + InnerXML: "rietveld<>", + }, + Entry: []Entry{ + { + Title: "rietveld: an attempt at pubsubhubbub\n", + Link: []Link{ + {Rel: "alternate", Href: "http://codereview.appspot.com/126085"}, + }, + Updated: ParseTime("2009-10-04T01:35:58+00:00"), + Author: Person{ + Name: "email-address-removed", + InnerXML: "email-address-removed", + }, + Id: "urn:md5:134d9179c41f806be79b3a5f7877d19a", + Summary: Text{ + Type: "html", + Body: ` + An attempt at adding pubsubhubbub support to Rietveld. +http://code.google.com/p/pubsubhubbub +http://code.google.com/p/rietveld/issues/detail?id=155 + +The server side of the protocol is trivial: + 1. add a <link rel="hub" href="hub-server"> tag to all + feeds that will be pubsubhubbubbed. + 2. every time one of those feeds changes, tell the hub + with a simple POST request. + +I have tested this by adding debug prints to a local hub +server and checking that the server got the right publish +requests. + +I can't quite get the server to work, but I think the bug +is not in my code. 
I think that the server expects to be +able to grab the feed and see the feed's actual URL in +the link rel="self", but the default value for that drops +the :port from the URL, and I cannot for the life of me +figure out how to get the Atom generator deep inside +django not to do that, or even where it is doing that, +or even what code is running to generate the Atom feed. +(I thought I knew but I added some assert False statements +and it kept running!) + +Ignoring that particular problem, I would appreciate +feedback on the right way to get the two values at +the top of feeds.py marked NOTE(rsc). + + +`, + }, + }, + { + Title: "rietveld: correct tab handling\n", + Link: []Link{ + {Rel: "alternate", Href: "http://codereview.appspot.com/124106"}, + }, + Updated: ParseTime("2009-10-03T23:02:17+00:00"), + Author: Person{ + Name: "email-address-removed", + InnerXML: "email-address-removed", + }, + Id: "urn:md5:0a2a4f19bb815101f0ba2904aed7c35a", + Summary: Text{ + Type: "html", + Body: ` + This fixes the buggy tab rendering that can be seen at +http://codereview.appspot.com/116075/diff/1/2 + +The fundamental problem was that the tab code was +not being told what column the text began in, so it +didn't know where to put the tab stops. Another problem +was that some of the code assumed that string byte +offsets were the same as column offsets, which is only +true if there are no tabs. + +In the process of fixing this, I cleaned up the arguments +to Fold and ExpandTabs and renamed them Break and +_ExpandTabs so that I could be sure that I found all the +call sites. I also wanted to verify that ExpandTabs was +not being used from outside intra_region_diff.py. + + +`, + }, + }, + }, +} + +const pathTestString = ` + + 1 + + + A + + + B + + + C + D + + <_> + E + + + 2 + +` + +type PathTestItem struct { + Value string +} + +type PathTestA struct { + Items []PathTestItem `xml:">Item1"` + Before, After string +} + +type PathTestB struct { + Other []PathTestItem `xml:"Items>Item1"` + Before, After string +} + +type PathTestC struct { + Values1 []string `xml:"Items>Item1>Value"` + Values2 []string `xml:"Items>Item2>Value"` + Before, After string +} + +type PathTestSet struct { + Item1 []PathTestItem +} + +type PathTestD struct { + Other PathTestSet `xml:"Items"` + Before, After string +} + +type PathTestE struct { + Underline string `xml:"Items>_>Value"` + Before, After string +} + +var pathTests = []interface{}{ + &PathTestA{Items: []PathTestItem{{"A"}, {"D"}}, Before: "1", After: "2"}, + &PathTestB{Other: []PathTestItem{{"A"}, {"D"}}, Before: "1", After: "2"}, + &PathTestC{Values1: []string{"A", "C", "D"}, Values2: []string{"B"}, Before: "1", After: "2"}, + &PathTestD{Other: PathTestSet{Item1: []PathTestItem{{"A"}, {"D"}}}, Before: "1", After: "2"}, + &PathTestE{Underline: "E", Before: "1", After: "2"}, +} + +func TestUnmarshalPaths(t *testing.T) { + for _, pt := range pathTests { + v := reflect.New(reflect.TypeOf(pt).Elem()).Interface() + if err := Unmarshal([]byte(pathTestString), v); err != nil { + t.Fatalf("Unmarshal: %s", err) + } + if !reflect.DeepEqual(v, pt) { + t.Fatalf("have %#v\nwant %#v", v, pt) + } + } +} + +type BadPathTestA struct { + First string `xml:"items>item1"` + Other string `xml:"items>item2"` + Second string `xml:"items"` +} + +type BadPathTestB struct { + Other string `xml:"items>item2>value"` + First string `xml:"items>item1"` + Second string `xml:"items>item1>value"` +} + +type BadPathTestC struct { + First string + Second string `xml:"First"` +} + +type BadPathTestD struct { 
+	BadPathEmbeddedA
+	BadPathEmbeddedB
+}
+
+type BadPathEmbeddedA struct {
+	First string
+}
+
+type BadPathEmbeddedB struct {
+	Second string `xml:"First"`
+}
+
+var badPathTests = []struct {
+	v, e interface{}
+}{
+	{&BadPathTestA{}, &TagPathError{reflect.TypeOf(BadPathTestA{}), "First", "items>item1", "Second", "items"}},
+	{&BadPathTestB{}, &TagPathError{reflect.TypeOf(BadPathTestB{}), "First", "items>item1", "Second", "items>item1>value"}},
+	{&BadPathTestC{}, &TagPathError{reflect.TypeOf(BadPathTestC{}), "First", "", "Second", "First"}},
+	{&BadPathTestD{}, &TagPathError{reflect.TypeOf(BadPathTestD{}), "First", "", "Second", "First"}},
+}
+
+func TestUnmarshalBadPaths(t *testing.T) {
+	for _, tt := range badPathTests {
+		err := Unmarshal([]byte(pathTestString), tt.v)
+		if !reflect.DeepEqual(err, tt.e) {
+			t.Fatalf("Unmarshal with %#v didn't fail properly:\nhave %#v,\nwant %#v", tt.v, err, tt.e)
+		}
+	}
+}
+
+const OK = "OK"
+const withoutNameTypeData = `
+<?xml version="1.0" charset="utf-8"?>
+<Test3 Attr="OK" />`
+
+type TestThree struct {
+	XMLName Name   `xml:"Test3"`
+	Attr    string `xml:",attr"`
+}
+
+func TestUnmarshalWithoutNameType(t *testing.T) {
+	var x TestThree
+	if err := Unmarshal([]byte(withoutNameTypeData), &x); err != nil {
+		t.Fatalf("Unmarshal: %s", err)
+	}
+	if x.Attr != OK {
+		t.Fatalf("have %v\nwant %v", x.Attr, OK)
+	}
+}
+
+func TestUnmarshalAttr(t *testing.T) {
+	type ParamVal struct {
+		Int int `xml:"int,attr"`
+	}
+
+	type ParamPtr struct {
+		Int *int `xml:"int,attr"`
+	}
+
+	type ParamStringPtr struct {
+		Int *string `xml:"int,attr"`
+	}
+
+	x := []byte(`<Param int="1" />`)
+
+	p1 := &ParamPtr{}
+	if err := Unmarshal(x, p1); err != nil {
+		t.Fatalf("Unmarshal: %s", err)
+	}
+	if p1.Int == nil {
+		t.Fatalf("Unmarshal failed in to *int field")
+	} else if *p1.Int != 1 {
+		t.Fatalf("Unmarshal with %s failed:\nhave %#v,\n want %#v", x, p1.Int, 1)
+	}
+
+	p2 := &ParamVal{}
+	if err := Unmarshal(x, p2); err != nil {
+		t.Fatalf("Unmarshal: %s", err)
+	}
+	if p2.Int != 1 {
+		t.Fatalf("Unmarshal with %s failed:\nhave %#v,\n want %#v", x, p2.Int, 1)
+	}
+
+	p3 := &ParamStringPtr{}
+	if err := Unmarshal(x, p3); err != nil {
+		t.Fatalf("Unmarshal: %s", err)
+	}
+	if p3.Int == nil {
+		t.Fatalf("Unmarshal failed in to *string field")
+	} else if *p3.Int != "1" {
+		t.Fatalf("Unmarshal with %s failed:\nhave %#v,\n want %#v", x, p3.Int, 1)
+	}
+}
+
+type Tables struct {
+	HTable string `xml:"http://www.w3.org/TR/html4/ table"`
+	FTable string `xml:"http://www.w3schools.com/furniture table"`
+}
+
+var tables = []struct {
+	xml string
+	tab Tables
+	ns  string
+}{
+	{
+		xml: `<Tables>` +
+			`<table xmlns="http://www.w3.org/TR/html4/">hello</table>` +
+			`<table xmlns="http://www.w3schools.com/furniture">world</table>` +
+			`</Tables>`,
+		tab: Tables{"hello", "world"},
+	},
+	{
+		xml: `<Tables>` +
+			`<table xmlns="http://www.w3schools.com/furniture">world</table>` +
+			`<table xmlns="http://www.w3.org/TR/html4/">hello</table>` +
+			`</Tables>`,
+		tab: Tables{"hello", "world"},
+	},
+	{
+		xml: `<Tables xmlns:f="http://www.w3schools.com/furniture" xmlns:h="http://www.w3.org/TR/html4/">` +
+			`<f:table>world</f:table>` +
+			`<h:table>hello</h:table>` +
+			`</Tables>`,
+		tab: Tables{"hello", "world"},
+	},
+	{
+		xml: `<Tables>` +
+			`<table>bogus</table>` +
+			`</Tables>`,
+		tab: Tables{},
+	},
+	{
+		xml: `<Tables>` +
+			`<table>only</table>` +
+			`</Tables>`,
+		tab: Tables{HTable: "only"},
+		ns:  "http://www.w3.org/TR/html4/",
+	},
+	{
+		xml: `<Tables>` +
+			`<table>only</table>` +
+			`</Tables>`,
+		tab: Tables{FTable: "only"},
+		ns:  "http://www.w3schools.com/furniture",
+	},
+	{
+		xml: `<Tables>` +
+			`<table>only</table>` +
+			`</Tables>`,
+		tab: Tables{},
+		ns:  "something else entirely",
+	},
+}
+
+func TestUnmarshalNS(t *testing.T) {
+	for i, tt := range tables {
+		var dst Tables
+		var err error
+		if tt.ns != "" {
+			d := NewDecoder(strings.NewReader(tt.xml))
+			d.DefaultSpace = tt.ns
+			err = d.Decode(&dst)
+		} else {
+			err = Unmarshal([]byte(tt.xml), &dst)
+		}
+		if err != nil {
+			t.Errorf("#%d: Unmarshal: %v", i, err)
+			continue
+		}
+		want := tt.tab
+		if dst != want {
+			t.Errorf("#%d: dst=%+v, want %+v", i, dst, want)
+		}
+	}
+}
+
+func TestRoundTrip(t *testing.T) {
+	// From issue 7535
+	const s = `<ex:element xmlns:ex="http://example.com/schema"></ex:element>`
+	in := bytes.NewBufferString(s)
+	for i := 0; i < 10; i++ {
+		out := &bytes.Buffer{}
+		d := NewDecoder(in)
+		e := NewEncoder(out)
+
+		for {
+			t, err := d.Token()
+			if err == io.EOF {
+				break
+			}
+			if err != nil {
+				fmt.Println("failed:", err)
+				return
+			}
+			e.EncodeToken(t)
+		}
+		e.Flush()
+		in = out
+	}
+	if got := in.String(); got != s {
+		t.Errorf("have: %q\nwant: %q\n", got, s)
+	}
+}
+
+func TestMarshalNS(t *testing.T) {
+	dst := Tables{"hello", "world"}
+	data, err := Marshal(&dst)
+	if err != nil {
+		t.Fatalf("Marshal: %v", err)
+	}
+	want := `<Tables><table xmlns="http://www.w3.org/TR/html4/">hello</table><table xmlns="http://www.w3schools.com/furniture">world</table></Tables>`
+	str := string(data)
+	if str != want {
+		t.Errorf("have: %q\nwant: %q\n", str, want)
+	}
+}
+
+type TableAttrs struct {
+	TAttr TAttr
+}
+
+type TAttr struct {
+	HTable string `xml:"http://www.w3.org/TR/html4/ table,attr"`
+	FTable string `xml:"http://www.w3schools.com/furniture table,attr"`
+	Lang   string `xml:"http://www.w3.org/XML/1998/namespace lang,attr,omitempty"`
+	Other1 string `xml:"http://golang.org/xml/ other,attr,omitempty"`
+	Other2 string `xml:"http://golang.org/xmlfoo/ other,attr,omitempty"`
+	Other3 string `xml:"http://golang.org/json/ other,attr,omitempty"`
+	Other4 string `xml:"http://golang.org/2/json/ other,attr,omitempty"`
+}
+
+var tableAttrs = []struct {
+	xml string
+	tab TableAttrs
+	ns  string
+}{
+	{
+		xml: `<TableAttrs xmlns:f="http://www.w3schools.com/furniture" xmlns:h="http://www.w3.org/TR/html4/"><TAttr ` +
+			`h:table="hello" f:table="world" /></TableAttrs>`,
+		tab: TableAttrs{TAttr{HTable: "hello", FTable: "world"}},
+	},
+	{
+		xml: `<TableAttrs><TAttr xmlns:f="http://www.w3schools.com/furniture" xmlns:h="http://www.w3.org/TR/html4/" ` +
+			`h:table="hello" f:table="world" /></TableAttrs>`,
+		tab: TableAttrs{TAttr{HTable: "hello", FTable: "world"}},
+	},
+	{
+		xml: `<TableAttrs><TAttr xmlns:h="http://www.w3.org/TR/html4/" h:table="hello" ` +
+			`xmlns:f="http://www.w3schools.com/furniture" f:table="world" /></TableAttrs>`,
+		tab: TableAttrs{TAttr{HTable: "hello", FTable: "world"}},
+	},
+	{
+		// Default space does not apply to attribute names.
+		xml: `<TableAttrs xmlns="http://www.w3schools.com/furniture" xmlns:h="http://www.w3.org/TR/html4/"><TAttr ` +
+			`h:table="hello" table="world" /></TableAttrs>`,
+		tab: TableAttrs{TAttr{HTable: "hello", FTable: ""}},
+	},
+	{
+		// Default space does not apply to attribute names.
+		xml: `<TableAttrs><TAttr xmlns="http://www.w3.org/TR/html4/" ` +
+			`table="hello" xmlns:f="http://www.w3schools.com/furniture" f:table="world" /></TableAttrs>`,
+		tab: TableAttrs{TAttr{HTable: "", FTable: "world"}},
+	},
+	{
+		xml: `<TableAttrs><TAttr table="bogus" /></TableAttrs>`,
+		tab: TableAttrs{},
+	},
+	{
+		// Default space does not apply to attribute names.
+		xml: `<TableAttrs xmlns:h="http://www.w3.org/TR/html4/"><TAttr ` +
+			`h:table="hello" table="world" /></TableAttrs>`,
+		tab: TableAttrs{TAttr{HTable: "hello", FTable: ""}},
+		ns:  "http://www.w3schools.com/furniture",
+	},
+	{
+		// Default space does not apply to attribute names.
+		xml: `<TableAttrs xmlns:f="http://www.w3schools.com/furniture"><TAttr ` +
+			`f:table="world" table="hello" /></TableAttrs>`,
+		tab: TableAttrs{TAttr{HTable: "", FTable: "world"}},
+		ns:  "http://www.w3.org/TR/html4/",
+	},
+	{
+		xml: `<TableAttrs><TAttr table="bogus" /></TableAttrs>`,
+		tab: TableAttrs{},
+		ns:  "something else entirely",
+	},
+}
+
+func TestUnmarshalNSAttr(t *testing.T) {
+	for i, tt := range tableAttrs {
+		var dst TableAttrs
+		var err error
+		if tt.ns != "" {
+			d := NewDecoder(strings.NewReader(tt.xml))
+			d.DefaultSpace = tt.ns
+			err = d.Decode(&dst)
+		} else {
+			err = Unmarshal([]byte(tt.xml), &dst)
+		}
+		if err != nil {
+			t.Errorf("#%d: Unmarshal: %v", i, err)
+			continue
+		}
+		want := tt.tab
+		if dst != want {
+			t.Errorf("#%d: dst=%+v, want %+v", i, dst, want)
+		}
+	}
+}
+
+func TestMarshalNSAttr(t *testing.T) {
+	src := TableAttrs{TAttr{"hello", "world", "en_US", "other1", "other2", "other3", "other4"}}
+	data, err := Marshal(&src)
+	if err != nil {
+		t.Fatalf("Marshal: %v", err)
+	}
+	want := ``
+	str := string(data)
+	if str != want {
+		t.Errorf("Marshal:\nhave: %#q\nwant: %#q\n", str, want)
+	}
+
+	var dst TableAttrs
+	if err := Unmarshal(data, &dst); err != nil {
+		t.Errorf("Unmarshal: %v", err)
+	}
+
+	if dst != src {
+		t.Errorf("Unmarshal = %q, want %q", dst, src)
+	}
+}
+
+type MyCharData struct {
+	body string
+}
+
+func (m *MyCharData) UnmarshalXML(d *Decoder, start StartElement) error {
+	for {
+		t, err := d.Token()
+		if err == io.EOF { // found end of element
+			break
+		}
+		if err != nil {
+			return err
+		}
+		if char, ok := t.(CharData); ok {
+			m.body += string(char)
+		}
+	}
+	return nil
+}
+
+var _ Unmarshaler = (*MyCharData)(nil)
+
+func (m *MyCharData) UnmarshalXMLAttr(attr Attr) error {
+	panic("must not call")
+}
+
+type MyAttr struct {
+	attr string
+}
+
+func (m *MyAttr) UnmarshalXMLAttr(attr Attr) error {
+	m.attr = attr.Value
+	return nil
+}
+
+var _ UnmarshalerAttr = (*MyAttr)(nil)
+
+type MyStruct struct {
+	Data *MyCharData
+	Attr *MyAttr `xml:",attr"`
+
+	Data2 MyCharData
+	Attr2 MyAttr `xml:",attr"`
+}
+
+func TestUnmarshaler(t *testing.T) {
+	xml := `<?xml version="1.0" encoding="utf-8"?>
+		<MyStruct Attr="attr1" Attr2="attr2">
+		<Data>hello <!-- comment -->world</Data>
+		<Data2>howdy <!-- comment -->world</Data2>
+		</MyStruct>
+	`
+
+	var m MyStruct
+	if err := 
Unmarshal([]byte(xml), &m); err != nil { + t.Fatal(err) + } + + if m.Data == nil || m.Attr == nil || m.Data.body != "hello world" || m.Attr.attr != "attr1" || m.Data2.body != "howdy world" || m.Attr2.attr != "attr2" { + t.Errorf("m=%#+v\n", m) + } +} + +type Pea struct { + Cotelydon string +} + +type Pod struct { + Pea interface{} `xml:"Pea"` +} + +// https://golang.org/issue/6836 +func TestUnmarshalIntoInterface(t *testing.T) { + pod := new(Pod) + pod.Pea = new(Pea) + xml := `Green stuff` + err := Unmarshal([]byte(xml), pod) + if err != nil { + t.Fatalf("failed to unmarshal %q: %v", xml, err) + } + pea, ok := pod.Pea.(*Pea) + if !ok { + t.Fatalf("unmarshalled into wrong type: have %T want *Pea", pod.Pea) + } + have, want := pea.Cotelydon, "Green stuff" + if have != want { + t.Errorf("failed to unmarshal into interface, have %q want %q", have, want) + } +} diff --git a/server/webdav/internal/xml/typeinfo.go b/server/webdav/internal/xml/typeinfo.go new file mode 100644 index 0000000000000000000000000000000000000000..fdde288bc37cc1edfd862a67c3b60701130338fb --- /dev/null +++ b/server/webdav/internal/xml/typeinfo.go @@ -0,0 +1,371 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package xml + +import ( + "fmt" + "reflect" + "strings" + "sync" +) + +// typeInfo holds details for the xml representation of a type. +type typeInfo struct { + xmlname *fieldInfo + fields []fieldInfo +} + +// fieldInfo holds details for the xml representation of a single field. +type fieldInfo struct { + idx []int + name string + xmlns string + flags fieldFlags + parents []string +} + +type fieldFlags int + +const ( + fElement fieldFlags = 1 << iota + fAttr + fCharData + fInnerXml + fComment + fAny + + fOmitEmpty + + fMode = fElement | fAttr | fCharData | fInnerXml | fComment | fAny +) + +var tinfoMap = make(map[reflect.Type]*typeInfo) +var tinfoLock sync.RWMutex + +var nameType = reflect.TypeOf(Name{}) + +// getTypeInfo returns the typeInfo structure with details necessary +// for marshalling and unmarshalling typ. +func getTypeInfo(typ reflect.Type) (*typeInfo, error) { + tinfoLock.RLock() + tinfo, ok := tinfoMap[typ] + tinfoLock.RUnlock() + if ok { + return tinfo, nil + } + tinfo = &typeInfo{} + if typ.Kind() == reflect.Struct && typ != nameType { + n := typ.NumField() + for i := 0; i < n; i++ { + f := typ.Field(i) + if f.PkgPath != "" || f.Tag.Get("xml") == "-" { + continue // Private field + } + + // For embedded structs, embed its fields. + if f.Anonymous { + t := f.Type + if t.Kind() == reflect.Ptr { + t = t.Elem() + } + if t.Kind() == reflect.Struct { + inner, err := getTypeInfo(t) + if err != nil { + return nil, err + } + if tinfo.xmlname == nil { + tinfo.xmlname = inner.xmlname + } + for _, finfo := range inner.fields { + finfo.idx = append([]int{i}, finfo.idx...) + if err := addFieldInfo(typ, tinfo, &finfo); err != nil { + return nil, err + } + } + continue + } + } + + finfo, err := structFieldInfo(typ, &f) + if err != nil { + return nil, err + } + + if f.Name == "XMLName" { + tinfo.xmlname = finfo + continue + } + + // Add the field if it doesn't conflict with other fields. + if err := addFieldInfo(typ, tinfo, finfo); err != nil { + return nil, err + } + } + } + tinfoLock.Lock() + tinfoMap[typ] = tinfo + tinfoLock.Unlock() + return tinfo, nil +} + +// structFieldInfo builds and returns a fieldInfo for f. 
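+// For example, a field declared as
+//
+//	Foo string `xml:"ns name,attr,omitempty"`
+//
+// produces a fieldInfo with xmlns "ns", name "name", and the flags
+// fAttr|fOmitEmpty set.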
+func structFieldInfo(typ reflect.Type, f *reflect.StructField) (*fieldInfo, error) { + finfo := &fieldInfo{idx: f.Index} + + // Split the tag from the xml namespace if necessary. + tag := f.Tag.Get("xml") + if i := strings.Index(tag, " "); i >= 0 { + finfo.xmlns, tag = tag[:i], tag[i+1:] + } + + // Parse flags. + tokens := strings.Split(tag, ",") + if len(tokens) == 1 { + finfo.flags = fElement + } else { + tag = tokens[0] + for _, flag := range tokens[1:] { + switch flag { + case "attr": + finfo.flags |= fAttr + case "chardata": + finfo.flags |= fCharData + case "innerxml": + finfo.flags |= fInnerXml + case "comment": + finfo.flags |= fComment + case "any": + finfo.flags |= fAny + case "omitempty": + finfo.flags |= fOmitEmpty + } + } + + // Validate the flags used. + valid := true + switch mode := finfo.flags & fMode; mode { + case 0: + finfo.flags |= fElement + case fAttr, fCharData, fInnerXml, fComment, fAny: + if f.Name == "XMLName" || tag != "" && mode != fAttr { + valid = false + } + default: + // This will also catch multiple modes in a single field. + valid = false + } + if finfo.flags&fMode == fAny { + finfo.flags |= fElement + } + if finfo.flags&fOmitEmpty != 0 && finfo.flags&(fElement|fAttr) == 0 { + valid = false + } + if !valid { + return nil, fmt.Errorf("xml: invalid tag in field %s of type %s: %q", + f.Name, typ, f.Tag.Get("xml")) + } + } + + // Use of xmlns without a name is not allowed. + if finfo.xmlns != "" && tag == "" { + return nil, fmt.Errorf("xml: namespace without name in field %s of type %s: %q", + f.Name, typ, f.Tag.Get("xml")) + } + + if f.Name == "XMLName" { + // The XMLName field records the XML element name. Don't + // process it as usual because its name should default to + // empty rather than to the field name. + finfo.name = tag + return finfo, nil + } + + if tag == "" { + // If the name part of the tag is completely empty, get + // default from XMLName of underlying struct if feasible, + // or field name otherwise. + if xmlname := lookupXMLName(f.Type); xmlname != nil { + finfo.xmlns, finfo.name = xmlname.xmlns, xmlname.name + } else { + finfo.name = f.Name + } + return finfo, nil + } + + if finfo.xmlns == "" && finfo.flags&fAttr == 0 { + // If it's an element no namespace specified, get the default + // from the XMLName of enclosing struct if possible. + if xmlname := lookupXMLName(typ); xmlname != nil { + finfo.xmlns = xmlname.xmlns + } + } + + // Prepare field name and parents. + parents := strings.Split(tag, ">") + if parents[0] == "" { + parents[0] = f.Name + } + if parents[len(parents)-1] == "" { + return nil, fmt.Errorf("xml: trailing '>' in field %s of type %s", f.Name, typ) + } + finfo.name = parents[len(parents)-1] + if len(parents) > 1 { + if (finfo.flags & fElement) == 0 { + return nil, fmt.Errorf("xml: %s chain not valid with %s flag", tag, strings.Join(tokens[1:], ",")) + } + finfo.parents = parents[:len(parents)-1] + } + + // If the field type has an XMLName field, the names must match + // so that the behavior of both marshalling and unmarshalling + // is straightforward and unambiguous. + if finfo.flags&fElement != 0 { + ftyp := f.Type + xmlname := lookupXMLName(ftyp) + if xmlname != nil && xmlname.name != finfo.name { + return nil, fmt.Errorf("xml: name %q in tag of %s.%s conflicts with name %q in %s.XMLName", + finfo.name, typ, f.Name, xmlname.name, ftyp) + } + } + return finfo, nil +} + +// lookupXMLName returns the fieldInfo for typ's XMLName field +// in case it exists and has a valid xml field tag, otherwise +// it returns nil. 
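+// For example, given
+//
+//	type T struct {
+//		XMLName Name `xml:"ns root"`
+//	}
+//
+// lookupXMLName(reflect.TypeOf(T{})) reports xmlns "ns" and name "root".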
+func lookupXMLName(typ reflect.Type) (xmlname *fieldInfo) { + for typ.Kind() == reflect.Ptr { + typ = typ.Elem() + } + if typ.Kind() != reflect.Struct { + return nil + } + for i, n := 0, typ.NumField(); i < n; i++ { + f := typ.Field(i) + if f.Name != "XMLName" { + continue + } + finfo, err := structFieldInfo(typ, &f) + if finfo.name != "" && err == nil { + return finfo + } + // Also consider errors as a non-existent field tag + // and let getTypeInfo itself report the error. + break + } + return nil +} + +func min(a, b int) int { + if a <= b { + return a + } + return b +} + +// addFieldInfo adds finfo to tinfo.fields if there are no +// conflicts, or if conflicts arise from previous fields that were +// obtained from deeper embedded structures than finfo. In the latter +// case, the conflicting entries are dropped. +// A conflict occurs when the path (parent + name) to a field is +// itself a prefix of another path, or when two paths match exactly. +// It is okay for field paths to share a common, shorter prefix. +func addFieldInfo(typ reflect.Type, tinfo *typeInfo, newf *fieldInfo) error { + var conflicts []int +Loop: + // First, figure all conflicts. Most working code will have none. + for i := range tinfo.fields { + oldf := &tinfo.fields[i] + if oldf.flags&fMode != newf.flags&fMode { + continue + } + if oldf.xmlns != "" && newf.xmlns != "" && oldf.xmlns != newf.xmlns { + continue + } + minl := min(len(newf.parents), len(oldf.parents)) + for p := 0; p < minl; p++ { + if oldf.parents[p] != newf.parents[p] { + continue Loop + } + } + if len(oldf.parents) > len(newf.parents) { + if oldf.parents[len(newf.parents)] == newf.name { + conflicts = append(conflicts, i) + } + } else if len(oldf.parents) < len(newf.parents) { + if newf.parents[len(oldf.parents)] == oldf.name { + conflicts = append(conflicts, i) + } + } else { + if newf.name == oldf.name { + conflicts = append(conflicts, i) + } + } + } + // Without conflicts, add the new field and return. + if conflicts == nil { + tinfo.fields = append(tinfo.fields, *newf) + return nil + } + + // If any conflict is shallower, ignore the new field. + // This matches the Go field resolution on embedding. + for _, i := range conflicts { + if len(tinfo.fields[i].idx) < len(newf.idx) { + return nil + } + } + + // Otherwise, if any of them is at the same depth level, it's an error. + for _, i := range conflicts { + oldf := &tinfo.fields[i] + if len(oldf.idx) == len(newf.idx) { + f1 := typ.FieldByIndex(oldf.idx) + f2 := typ.FieldByIndex(newf.idx) + return &TagPathError{typ, f1.Name, f1.Tag.Get("xml"), f2.Name, f2.Tag.Get("xml")} + } + } + + // Otherwise, the new field is shallower, and thus takes precedence, + // so drop the conflicting fields from tinfo and append the new one. + for c := len(conflicts) - 1; c >= 0; c-- { + i := conflicts[c] + copy(tinfo.fields[i:], tinfo.fields[i+1:]) + tinfo.fields = tinfo.fields[:len(tinfo.fields)-1] + } + tinfo.fields = append(tinfo.fields, *newf) + return nil +} + +// A TagPathError represents an error in the unmarshalling process +// caused by the use of field tags with conflicting paths. +type TagPathError struct { + Struct reflect.Type + Field1, Tag1 string + Field2, Tag2 string +} + +func (e *TagPathError) Error() string { + return fmt.Sprintf("%s field %q with tag %q conflicts with field %q with tag %q", e.Struct, e.Field1, e.Tag1, e.Field2, e.Tag2) +} + +// value returns v's field value corresponding to finfo. 
+// It's equivalent to v.FieldByIndex(finfo.idx), but initializes +// and dereferences pointers as necessary. +func (finfo *fieldInfo) value(v reflect.Value) reflect.Value { + for i, x := range finfo.idx { + if i > 0 { + t := v.Type() + if t.Kind() == reflect.Ptr && t.Elem().Kind() == reflect.Struct { + if v.IsNil() { + v.Set(reflect.New(v.Type().Elem())) + } + v = v.Elem() + } + } + v = v.Field(x) + } + return v +} diff --git a/server/webdav/internal/xml/xml.go b/server/webdav/internal/xml/xml.go new file mode 100644 index 0000000000000000000000000000000000000000..7d88dac7b35763a3dda4a2d9982f5cd152c8cc57 --- /dev/null +++ b/server/webdav/internal/xml/xml.go @@ -0,0 +1,1998 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package xml implements a simple XML 1.0 parser that +// understands XML name spaces. +package xml + +// References: +// Annotated XML spec: http://www.xml.com/axml/testaxml.htm +// XML name spaces: http://www.w3.org/TR/REC-xml-names/ + +// TODO(rsc): +// Test error handling. + +import ( + "bufio" + "bytes" + "errors" + "fmt" + "io" + "strconv" + "strings" + "unicode" + "unicode/utf8" +) + +// A SyntaxError represents a syntax error in the XML input stream. +type SyntaxError struct { + Msg string + Line int +} + +func (e *SyntaxError) Error() string { + return "XML syntax error on line " + strconv.Itoa(e.Line) + ": " + e.Msg +} + +// A Name represents an XML name (Local) annotated with a name space +// identifier (Space). In tokens returned by Decoder.Token, the Space +// identifier is given as a canonical URL, not the short prefix used in +// the document being parsed. +// +// As a special case, XML namespace declarations will use the literal +// string "xmlns" for the Space field instead of the fully resolved URL. +// See Encoder.EncodeToken for more information on namespace encoding +// behaviour. +type Name struct { + Space, Local string +} + +// isNamespace reports whether the name is a namespace-defining name. +func (name Name) isNamespace() bool { + return name.Local == "xmlns" || name.Space == "xmlns" +} + +// An Attr represents an attribute in an XML element (Name=Value). +type Attr struct { + Name Name + Value string +} + +// A Token is an interface holding one of the token types: +// StartElement, EndElement, CharData, Comment, ProcInst, or Directive. +type Token interface{} + +// A StartElement represents an XML start element. +type StartElement struct { + Name Name + Attr []Attr +} + +func (e StartElement) Copy() StartElement { + attrs := make([]Attr, len(e.Attr)) + copy(attrs, e.Attr) + e.Attr = attrs + return e +} + +// End returns the corresponding XML end element. +func (e StartElement) End() EndElement { + return EndElement{e.Name} +} + +// setDefaultNamespace sets the namespace of the element +// as the default for all elements contained within it. +func (e *StartElement) setDefaultNamespace() { + if e.Name.Space == "" { + // If there's no namespace on the element, don't + // set the default. Strictly speaking this might be wrong, as + // we can't tell if the element had no namespace set + // or was just using the default namespace. + return + } + // Don't add a default name space if there's already one set. 
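+	// For example, a StartElement with Name{Space: "DAV:", Local: "multistatus"}
+	// and no xmlns attribute gains Attr{Name{"", "xmlns"}, "DAV:"} here.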
+	for _, attr := range e.Attr {
+		if attr.Name.Space == "" && attr.Name.Local == "xmlns" {
+			return
+		}
+	}
+	e.Attr = append(e.Attr, Attr{
+		Name: Name{
+			Local: "xmlns",
+		},
+		Value: e.Name.Space,
+	})
+}
+
+// An EndElement represents an XML end element.
+type EndElement struct {
+	Name Name
+}
+
+// A CharData represents XML character data (raw text),
+// in which XML escape sequences have been replaced by
+// the characters they represent.
+type CharData []byte
+
+func makeCopy(b []byte) []byte {
+	b1 := make([]byte, len(b))
+	copy(b1, b)
+	return b1
+}
+
+func (c CharData) Copy() CharData { return CharData(makeCopy(c)) }
+
+// A Comment represents an XML comment of the form <!--comment-->.
+// The bytes do not include the <!-- and --> comment markers.
+type Comment []byte
+
+func (c Comment) Copy() Comment { return Comment(makeCopy(c)) }
+
+// A ProcInst represents an XML processing instruction of the form <?target inst?>
+type ProcInst struct {
+	Target string
+	Inst   []byte
+}
+
+func (p ProcInst) Copy() ProcInst {
+	p.Inst = makeCopy(p.Inst)
+	return p
+}
+
+// A Directive represents an XML directive of the form <!text>.
+// The bytes do not include the <! and > markers.
+type Directive []byte
+
+func (d Directive) Copy() Directive { return Directive(makeCopy(d)) }
+
+// CopyToken returns a copy of a Token.
+func CopyToken(t Token) Token {
+	switch v := t.(type) {
+	case CharData:
+		return v.Copy()
+	case Comment:
+		return v.Copy()
+	case Directive:
+		return v.Copy()
+	case ProcInst:
+		return v.Copy()
+	case StartElement:
+		return v.Copy()
+	}
+	return t
+}
+
+// A Decoder represents an XML parser reading a particular input stream.
+// The parser assumes that its input is encoded in UTF-8.
+type Decoder struct {
+	// Strict defaults to true, enforcing the requirements
+	// of the XML specification.
+	// If set to false, the parser allows input containing common
+	// mistakes:
+	//	* If an element is missing an end tag, the parser invents
+	//	  end tags as necessary to keep the return values from Token
+	//	  properly balanced.
+	//	* In attribute values and character data, unknown or malformed
+	//	  character entities (sequences beginning with &) are left alone.
+	//
+	// Setting:
+	//
+	//	d.Strict = false;
+	//	d.AutoClose = HTMLAutoClose;
+	//	d.Entity = HTMLEntity
+	//
+	// creates a parser that can handle typical HTML.
+	//
+	// Strict mode does not enforce the requirements of the XML name spaces TR.
+	// In particular it does not reject name space tags using undefined prefixes.
+	// Such tags are recorded with the unknown prefix as the name space URL.
+	Strict bool
+
+	// When Strict == false, AutoClose indicates a set of elements to
+	// consider closed immediately after they are opened, regardless
+	// of whether an end element is present.
+	AutoClose []string
+
+	// Entity can be used to map non-standard entity names to string replacements.
+	// The parser behaves as if these standard mappings are present in the map,
+	// regardless of the actual map content:
+	//
+	//	"lt": "<",
+	//	"gt": ">",
+	//	"amp": "&",
+	//	"apos": "'",
+	//	"quot": `"`,
+	Entity map[string]string
+
+	// CharsetReader, if non-nil, defines a function to generate
+	// charset-conversion readers, converting from the provided
+	// non-UTF-8 charset into UTF-8. If CharsetReader is nil or
+	// returns an error, parsing stops with an error. One of the
+	// CharsetReader's result values must be non-nil.
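+	//
+	// A minimal sketch of a CharsetReader (assuming the
+	// golang.org/x/text/encoding/charmap package):
+	//
+	//	d.CharsetReader = func(charset string, input io.Reader) (io.Reader, error) {
+	//		if strings.EqualFold(charset, "iso-8859-1") {
+	//			return charmap.ISO8859_1.NewDecoder().Reader(input), nil
+	//		}
+	//		return nil, fmt.Errorf("unsupported charset %q", charset)
+	//	}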
+	CharsetReader func(charset string, input io.Reader) (io.Reader, error)
+
+	// DefaultSpace sets the default name space used for unadorned tags,
+	// as if the entire XML stream were wrapped in an element containing
+	// the attribute xmlns="DefaultSpace".
+	DefaultSpace string
+
+	r              io.ByteReader
+	buf            bytes.Buffer
+	saved          *bytes.Buffer
+	stk            *stack
+	free           *stack
+	needClose      bool
+	toClose        Name
+	nextToken      Token
+	nextByte       int
+	ns             map[string]string
+	err            error
+	line           int
+	offset         int64
+	unmarshalDepth int
+}
+
+// NewDecoder creates a new XML parser reading from r.
+// If r does not implement io.ByteReader, NewDecoder will
+// do its own buffering.
+func NewDecoder(r io.Reader) *Decoder {
+	d := &Decoder{
+		ns:       make(map[string]string),
+		nextByte: -1,
+		line:     1,
+		Strict:   true,
+	}
+	d.switchToReader(r)
+	return d
+}
+
+// Token returns the next XML token in the input stream.
+// At the end of the input stream, Token returns nil, io.EOF.
+//
+// Slices of bytes in the returned token data refer to the
+// parser's internal buffer and remain valid only until the next
+// call to Token. To acquire a copy of the bytes, call CopyToken
+// or the token's Copy method.
+//
+// Token expands self-closing elements such as <br/>
+// into separate start and end elements returned by successive calls. +// +// Token guarantees that the StartElement and EndElement +// tokens it returns are properly nested and matched: +// if Token encounters an unexpected end element, +// it will return an error. +// +// Token implements XML name spaces as described by +// http://www.w3.org/TR/REC-xml-names/. Each of the +// Name structures contained in the Token has the Space +// set to the URL identifying its name space when known. +// If Token encounters an unrecognized name space prefix, +// it uses the prefix as the Space rather than report an error. +func (d *Decoder) Token() (t Token, err error) { + if d.stk != nil && d.stk.kind == stkEOF { + err = io.EOF + return + } + if d.nextToken != nil { + t = d.nextToken + d.nextToken = nil + } else if t, err = d.rawToken(); err != nil { + return + } + + if !d.Strict { + if t1, ok := d.autoClose(t); ok { + d.nextToken = t + t = t1 + } + } + switch t1 := t.(type) { + case StartElement: + // In XML name spaces, the translations listed in the + // attributes apply to the element name and + // to the other attribute names, so process + // the translations first. + for _, a := range t1.Attr { + if a.Name.Space == "xmlns" { + v, ok := d.ns[a.Name.Local] + d.pushNs(a.Name.Local, v, ok) + d.ns[a.Name.Local] = a.Value + } + if a.Name.Space == "" && a.Name.Local == "xmlns" { + // Default space for untagged names + v, ok := d.ns[""] + d.pushNs("", v, ok) + d.ns[""] = a.Value + } + } + + d.translate(&t1.Name, true) + for i := range t1.Attr { + d.translate(&t1.Attr[i].Name, false) + } + d.pushElement(t1.Name) + t = t1 + + case EndElement: + d.translate(&t1.Name, true) + if !d.popElement(&t1) { + return nil, d.err + } + t = t1 + } + return +} + +const xmlURL = "http://www.w3.org/XML/1998/namespace" + +// Apply name space translation to name n. +// The default name space (for Space=="") +// applies only to element names, not to attribute names. +func (d *Decoder) translate(n *Name, isElementName bool) { + switch { + case n.Space == "xmlns": + return + case n.Space == "" && !isElementName: + return + case n.Space == "xml": + n.Space = xmlURL + case n.Space == "" && n.Local == "xmlns": + return + } + if v, ok := d.ns[n.Space]; ok { + n.Space = v + } else if n.Space == "" { + n.Space = d.DefaultSpace + } +} + +func (d *Decoder) switchToReader(r io.Reader) { + // Get efficient byte at a time reader. + // Assume that if reader has its own + // ReadByte, it's efficient enough. + // Otherwise, use bufio. + if rb, ok := r.(io.ByteReader); ok { + d.r = rb + } else { + d.r = bufio.NewReader(r) + } +} + +// Parsing state - stack holds old name space translations +// and the current set of open elements. The translations to pop when +// ending a given tag are *below* it on the stack, which is +// more work but forced on us by XML. +type stack struct { + next *stack + kind int + name Name + ok bool +} + +const ( + stkStart = iota + stkNs + stkEOF +) + +func (d *Decoder) push(kind int) *stack { + s := d.free + if s != nil { + d.free = s.next + } else { + s = new(stack) + } + s.next = d.stk + s.kind = kind + d.stk = s + return s +} + +func (d *Decoder) pop() *stack { + s := d.stk + if s != nil { + d.stk = s.next + s.next = d.free + d.free = s + } + return s +} + +// Record that after the current element is finished +// (that element is already pushed on the stack) +// Token should return EOF until popEOF is called. +func (d *Decoder) pushEOF() { + // Walk down stack to find Start. 
+	// It might not be the top, because there might be stkNs
+	// entries above it.
+	start := d.stk
+	for start.kind != stkStart {
+		start = start.next
+	}
+	// The stkNs entries below a start are associated with that
+	// element too; skip over them.
+	for start.next != nil && start.next.kind == stkNs {
+		start = start.next
+	}
+	s := d.free
+	if s != nil {
+		d.free = s.next
+	} else {
+		s = new(stack)
+	}
+	s.kind = stkEOF
+	s.next = start.next
+	start.next = s
+}
+
+// Undo a pushEOF.
+// The element must have been finished, so the EOF should be at the top of the stack.
+func (d *Decoder) popEOF() bool {
+	if d.stk == nil || d.stk.kind != stkEOF {
+		return false
+	}
+	d.pop()
+	return true
+}
+
+// Record that we are starting an element with the given name.
+func (d *Decoder) pushElement(name Name) {
+	s := d.push(stkStart)
+	s.name = name
+}
+
+// Record that we are changing the value of ns[local].
+// The old value is url, ok.
+func (d *Decoder) pushNs(local string, url string, ok bool) {
+	s := d.push(stkNs)
+	s.name.Local = local
+	s.name.Space = url
+	s.ok = ok
+}
+
+// Creates a SyntaxError with the current line number.
+func (d *Decoder) syntaxError(msg string) error {
+	return &SyntaxError{Msg: msg, Line: d.line}
+}
+
+// Record that we are ending an element with the given name.
+// The name must match the record at the top of the stack,
+// which must be a pushElement record.
+// After popping the element, apply any undo records from
+// the stack to restore the name translations that existed
+// before we saw this element.
+func (d *Decoder) popElement(t *EndElement) bool {
+	s := d.pop()
+	name := t.Name
+	switch {
+	case s == nil || s.kind != stkStart:
+		d.err = d.syntaxError("unexpected end element </" + name.Local + ">")
+		return false
+	case s.name.Local != name.Local:
+		if !d.Strict {
+			d.needClose = true
+			d.toClose = t.Name
+			t.Name = s.name
+			return true
+		}
+		d.err = d.syntaxError("element <" + s.name.Local + "> closed by </" + name.Local + ">")
+		return false
+	case s.name.Space != name.Space:
+		d.err = d.syntaxError("element <" + s.name.Local + "> in space " + s.name.Space +
+			"closed by </" + name.Local + "> in space " + name.Space)
+		return false
+	}
+
+	// Pop stack until a Start or EOF is on the top, undoing the
+	// translations that were associated with the element we just closed.
+	for d.stk != nil && d.stk.kind != stkStart && d.stk.kind != stkEOF {
+		s := d.pop()
+		if s.ok {
+			d.ns[s.name.Local] = s.name.Space
+		} else {
+			delete(d.ns, s.name.Local)
+		}
+	}
+
+	return true
+}
+
+// If the top element on the stack is autoclosing and
+// t is not the end tag, invent the end tag.
+func (d *Decoder) autoClose(t Token) (Token, bool) {
+	if d.stk == nil || d.stk.kind != stkStart {
+		return nil, false
+	}
+	name := strings.ToLower(d.stk.name.Local)
+	for _, s := range d.AutoClose {
+		if strings.ToLower(s) == name {
+			// This one should be auto closed if t doesn't close it.
+			et, ok := t.(EndElement)
+			if !ok || et.Name.Local != name {
+				return EndElement{d.stk.name}, true
+			}
+			break
+		}
+	}
+	return nil, false
+}
+
+var errRawToken = errors.New("xml: cannot use RawToken from UnmarshalXML method")
+
+// RawToken is like Token but does not verify that
+// start and end elements match and does not translate
+// name space prefixes to their corresponding URLs.
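+//
+// A minimal sketch of a caller (hypothetical, not part of this package):
+//
+//	d := NewDecoder(strings.NewReader(`<d:a xmlns:d="DAV:"><d:b/></d:a>`))
+//	for {
+//		tok, err := d.RawToken()
+//		if err != nil {
+//			break // io.EOF at end of input
+//		}
+//		fmt.Printf("%T %v\n", tok, tok)
+//	}
+//
+// Unlike Token, the prefix "d" is reported verbatim in Name.Space.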
+func (d *Decoder) RawToken() (Token, error) {
+	if d.unmarshalDepth > 0 {
+		return nil, errRawToken
+	}
+	return d.rawToken()
+}
+
+func (d *Decoder) rawToken() (Token, error) {
+	if d.err != nil {
+		return nil, d.err
+	}
+	if d.needClose {
+		// The last element we read was self-closing and
+		// we returned just the StartElement half.
+		// Return the EndElement half now.
+		d.needClose = false
+		return EndElement{d.toClose}, nil
+	}
+
+	b, ok := d.getc()
+	if !ok {
+		return nil, d.err
+	}
+
+	if b != '<' {
+		// Text section.
+		d.ungetc(b)
+		data := d.text(-1, false)
+		if data == nil {
+			return nil, d.err
+		}
+		return CharData(data), nil
+	}
+
+	if b, ok = d.mustgetc(); !ok {
+		return nil, d.err
+	}
+	switch b {
+	case '/':
+		// </: End element
+		var name Name
+		if name, ok = d.nsname(); !ok {
+			if d.err == nil {
+				d.err = d.syntaxError("expected element name after </")
+			}
+			return nil, d.err
+		}
+		d.space()
+		if b, ok = d.mustgetc(); !ok {
+			return nil, d.err
+		}
+		if b != '>' {
+			d.err = d.syntaxError("invalid characters between </" + name.Local + " and >")
+			return nil, d.err
+		}
+		return EndElement{name}, nil
+
+	case '?':
+		// <?: Processing instruction.
+		var target string
+		if target, ok = d.name(); !ok {
+			if d.err == nil {
+				d.err = d.syntaxError("expected target name after <?")
+			}
+			return nil, d.err
+		}
+		d.space()
+		d.buf.Reset()
+		var b0 byte
+		for {
+			if b, ok = d.mustgetc(); !ok {
+				return nil, d.err
+			}
+			d.buf.WriteByte(b)
+			if b0 == '?' && b == '>' {
+				break
+			}
+			b0 = b
+		}
+		data := d.buf.Bytes()
+		data = data[0 : len(data)-2] // chop ?>
+
+		if target == "xml" {
+			content := string(data)
+			ver := procInst("version", content)
+			if ver != "" && ver != "1.0" {
+				d.err = fmt.Errorf("xml: unsupported version %q; only version 1.0 is supported", ver)
+				return nil, d.err
+			}
+			enc := procInst("encoding", content)
+			if enc != "" && enc != "utf-8" && enc != "UTF-8" {
+				if d.CharsetReader == nil {
+					d.err = fmt.Errorf("xml: encoding %q declared but Decoder.CharsetReader is nil", enc)
+					return nil, d.err
+				}
+				newr, err := d.CharsetReader(enc, d.r.(io.Reader))
+				if err != nil {
+					d.err = fmt.Errorf("xml: opening charset %q: %v", enc, err)
+					return nil, d.err
+				}
+				if newr == nil {
+					panic("CharsetReader returned a nil Reader for charset " + enc)
+				}
+				d.switchToReader(newr)
+			}
+		}
+		return ProcInst{target, data}, nil
+
+	case '!':
+		// <!: Maybe comment, maybe CDATA.
+		if b, ok = d.mustgetc(); !ok {
+			return nil, d.err
+		}
+		switch b {
+		case '-': // <!-
+			// Probably <!-- for a comment.
+			if b, ok = d.mustgetc(); !ok {
+				return nil, d.err
+			}
+			if b != '-' {
+				d.err = d.syntaxError("invalid sequence <!- not part of <!--")
+				return nil, d.err
+			}
+			// Look for terminator.
+			d.buf.Reset()
+			var b0, b1 byte
+			for {
+				if b, ok = d.mustgetc(); !ok {
+					return nil, d.err
+				}
+				d.buf.WriteByte(b)
+				if b0 == '-' && b1 == '-' && b == '>' {
+					break
+				}
+				b0, b1 = b1, b
+			}
+			data := d.buf.Bytes()
+			data = data[0 : len(data)-3] // chop -->
+			return Comment(data), nil
+
+		case '[': // <![
+			// Probably <![CDATA[.
+			for i := 0; i < 6; i++ {
+				if b, ok = d.mustgetc(); !ok {
+					return nil, d.err
+				}
+				if b != "CDATA["[i] {
+					d.err = d.syntaxError("invalid <![ sequence")
+					return nil, d.err
+				}
+			}
+			// Have <![CDATA[.  Read text until ]]>.
+			data := d.text(-1, true)
+			if data == nil {
+				return nil, d.err
+			}
+			return CharData(data), nil
+		}
+
+		// Probably a directive: <!DOCTYPE ...>, <!ENTITY ...>, etc.
+		// We don't care, but accumulate for caller. Quoted angle
+		// brackets do not count for nesting.
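+	// For example, <!DOCTYPE doc [<!ENTITY e "a > b">]> is one directive:
+	// the nested <!ENTITY ...> raises the depth, and the '>' inside the
+	// quoted entity value is ignored.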
+ d.buf.Reset() + d.buf.WriteByte(b) + inquote := uint8(0) + depth := 0 + for { + if b, ok = d.mustgetc(); !ok { + return nil, d.err + } + if inquote == 0 && b == '>' && depth == 0 { + break + } + HandleB: + d.buf.WriteByte(b) + switch { + case b == inquote: + inquote = 0 + + case inquote != 0: + // in quotes, no special action + + case b == '\'' || b == '"': + inquote = b + + case b == '>' && inquote == 0: + depth-- + + case b == '<' && inquote == 0: + // Look for ` + +var testEntity = map[string]string{"何": "What", "is-it": "is it?"} + +var rawTokens = []Token{ + CharData("\n"), + ProcInst{"xml", []byte(`version="1.0" encoding="UTF-8"`)}, + CharData("\n"), + Directive(`DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN" + "http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd"`), + CharData("\n"), + StartElement{Name{"", "body"}, []Attr{{Name{"xmlns", "foo"}, "ns1"}, {Name{"", "xmlns"}, "ns2"}, {Name{"xmlns", "tag"}, "ns3"}}}, + CharData("\n "), + StartElement{Name{"", "hello"}, []Attr{{Name{"", "lang"}, "en"}}}, + CharData("World <>'\" 白鵬翔"), + EndElement{Name{"", "hello"}}, + CharData("\n "), + StartElement{Name{"", "query"}, []Attr{}}, + CharData("What is it?"), + EndElement{Name{"", "query"}}, + CharData("\n "), + StartElement{Name{"", "goodbye"}, []Attr{}}, + EndElement{Name{"", "goodbye"}}, + CharData("\n "), + StartElement{Name{"", "outer"}, []Attr{{Name{"foo", "attr"}, "value"}, {Name{"xmlns", "tag"}, "ns4"}}}, + CharData("\n "), + StartElement{Name{"", "inner"}, []Attr{}}, + EndElement{Name{"", "inner"}}, + CharData("\n "), + EndElement{Name{"", "outer"}}, + CharData("\n "), + StartElement{Name{"tag", "name"}, []Attr{}}, + CharData("\n "), + CharData("Some text here."), + CharData("\n "), + EndElement{Name{"tag", "name"}}, + CharData("\n"), + EndElement{Name{"", "body"}}, + Comment(" missing final newline "), +} + +var cookedTokens = []Token{ + CharData("\n"), + ProcInst{"xml", []byte(`version="1.0" encoding="UTF-8"`)}, + CharData("\n"), + Directive(`DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN" + "http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd"`), + CharData("\n"), + StartElement{Name{"ns2", "body"}, []Attr{{Name{"xmlns", "foo"}, "ns1"}, {Name{"", "xmlns"}, "ns2"}, {Name{"xmlns", "tag"}, "ns3"}}}, + CharData("\n "), + StartElement{Name{"ns2", "hello"}, []Attr{{Name{"", "lang"}, "en"}}}, + CharData("World <>'\" 白鵬翔"), + EndElement{Name{"ns2", "hello"}}, + CharData("\n "), + StartElement{Name{"ns2", "query"}, []Attr{}}, + CharData("What is it?"), + EndElement{Name{"ns2", "query"}}, + CharData("\n "), + StartElement{Name{"ns2", "goodbye"}, []Attr{}}, + EndElement{Name{"ns2", "goodbye"}}, + CharData("\n "), + StartElement{Name{"ns2", "outer"}, []Attr{{Name{"ns1", "attr"}, "value"}, {Name{"xmlns", "tag"}, "ns4"}}}, + CharData("\n "), + StartElement{Name{"ns2", "inner"}, []Attr{}}, + EndElement{Name{"ns2", "inner"}}, + CharData("\n "), + EndElement{Name{"ns2", "outer"}}, + CharData("\n "), + StartElement{Name{"ns3", "name"}, []Attr{}}, + CharData("\n "), + CharData("Some text here."), + CharData("\n "), + EndElement{Name{"ns3", "name"}}, + CharData("\n"), + EndElement{Name{"ns2", "body"}}, + Comment(" missing final newline "), +} + +const testInputAltEncoding = ` + +VALUE` + +var rawTokensAltEncoding = []Token{ + CharData("\n"), + ProcInst{"xml", []byte(`version="1.0" encoding="x-testing-uppercase"`)}, + CharData("\n"), + StartElement{Name{"", "tag"}, []Attr{}}, + CharData("value"), + EndElement{Name{"", "tag"}}, +} + +var xmlInput = []string{ + // 
unexpected EOF cases + "<", + "", + "", + "", + // "", // let the Token() caller handle + "", + "", + "", + "", + " c;", + "", + "", + "", + // "", // let the Token() caller handle + "", + "", + "cdata]]>", +} + +func TestRawToken(t *testing.T) { + d := NewDecoder(strings.NewReader(testInput)) + d.Entity = testEntity + testRawToken(t, d, testInput, rawTokens) +} + +const nonStrictInput = ` +non&entity +&unknown;entity +{ +&#zzz; +&なまえ3; +<-gt; +&; +&0a; +` + +var nonStringEntity = map[string]string{"": "oops!", "0a": "oops!"} + +var nonStrictTokens = []Token{ + CharData("\n"), + StartElement{Name{"", "tag"}, []Attr{}}, + CharData("non&entity"), + EndElement{Name{"", "tag"}}, + CharData("\n"), + StartElement{Name{"", "tag"}, []Attr{}}, + CharData("&unknown;entity"), + EndElement{Name{"", "tag"}}, + CharData("\n"), + StartElement{Name{"", "tag"}, []Attr{}}, + CharData("{"), + EndElement{Name{"", "tag"}}, + CharData("\n"), + StartElement{Name{"", "tag"}, []Attr{}}, + CharData("&#zzz;"), + EndElement{Name{"", "tag"}}, + CharData("\n"), + StartElement{Name{"", "tag"}, []Attr{}}, + CharData("&なまえ3;"), + EndElement{Name{"", "tag"}}, + CharData("\n"), + StartElement{Name{"", "tag"}, []Attr{}}, + CharData("<-gt;"), + EndElement{Name{"", "tag"}}, + CharData("\n"), + StartElement{Name{"", "tag"}, []Attr{}}, + CharData("&;"), + EndElement{Name{"", "tag"}}, + CharData("\n"), + StartElement{Name{"", "tag"}, []Attr{}}, + CharData("&0a;"), + EndElement{Name{"", "tag"}}, + CharData("\n"), +} + +func TestNonStrictRawToken(t *testing.T) { + d := NewDecoder(strings.NewReader(nonStrictInput)) + d.Strict = false + testRawToken(t, d, nonStrictInput, nonStrictTokens) +} + +type downCaser struct { + t *testing.T + r io.ByteReader +} + +func (d *downCaser) ReadByte() (c byte, err error) { + c, err = d.r.ReadByte() + if c >= 'A' && c <= 'Z' { + c += 'a' - 'A' + } + return +} + +func (d *downCaser) Read(p []byte) (int, error) { + d.t.Fatalf("unexpected Read call on downCaser reader") + panic("unreachable") +} + +func TestRawTokenAltEncoding(t *testing.T) { + d := NewDecoder(strings.NewReader(testInputAltEncoding)) + d.CharsetReader = func(charset string, input io.Reader) (io.Reader, error) { + if charset != "x-testing-uppercase" { + t.Fatalf("unexpected charset %q", charset) + } + return &downCaser{t, input.(io.ByteReader)}, nil + } + testRawToken(t, d, testInputAltEncoding, rawTokensAltEncoding) +} + +func TestRawTokenAltEncodingNoConverter(t *testing.T) { + d := NewDecoder(strings.NewReader(testInputAltEncoding)) + token, err := d.RawToken() + if token == nil { + t.Fatalf("expected a token on first RawToken call") + } + if err != nil { + t.Fatal(err) + } + token, err = d.RawToken() + if token != nil { + t.Errorf("expected a nil token; got %#v", token) + } + if err == nil { + t.Fatalf("expected an error on second RawToken call") + } + const encoding = "x-testing-uppercase" + if !strings.Contains(err.Error(), encoding) { + t.Errorf("expected error to contain %q; got error: %v", + encoding, err) + } +} + +func testRawToken(t *testing.T, d *Decoder, raw string, rawTokens []Token) { + lastEnd := int64(0) + for i, want := range rawTokens { + start := d.InputOffset() + have, err := d.RawToken() + end := d.InputOffset() + if err != nil { + t.Fatalf("token %d: unexpected error: %s", i, err) + } + if !reflect.DeepEqual(have, want) { + var shave, swant string + if _, ok := have.(CharData); ok { + shave = fmt.Sprintf("CharData(%q)", have) + } else { + shave = fmt.Sprintf("%#v", have) + } + if _, ok := want.(CharData); ok { + 
swant = fmt.Sprintf("CharData(%q)", want) + } else { + swant = fmt.Sprintf("%#v", want) + } + t.Errorf("token %d = %s, want %s", i, shave, swant) + } + + // Check that InputOffset returned actual token. + switch { + case start < lastEnd: + t.Errorf("token %d: position [%d,%d) for %T is before previous token", i, start, end, have) + case start >= end: + // Special case: EndElement can be synthesized. + if start == end && end == lastEnd { + break + } + t.Errorf("token %d: position [%d,%d) for %T is empty", i, start, end, have) + case end > int64(len(raw)): + t.Errorf("token %d: position [%d,%d) for %T extends beyond input", i, start, end, have) + default: + text := raw[start:end] + if strings.ContainsAny(text, "<>") && (!strings.HasPrefix(text, "<") || !strings.HasSuffix(text, ">")) { + t.Errorf("token %d: misaligned raw token %#q for %T", i, text, have) + } + } + lastEnd = end + } +} + +// Ensure that directives (specifically !DOCTYPE) include the complete +// text of any nested directives, noting that < and > do not change +// nesting depth if they are in single or double quotes. + +var nestedDirectivesInput = ` +]> +">]> +]> +'>]> +]> +'>]> +]> +` + +var nestedDirectivesTokens = []Token{ + CharData("\n"), + Directive(`DOCTYPE []`), + CharData("\n"), + Directive(`DOCTYPE [">]`), + CharData("\n"), + Directive(`DOCTYPE []`), + CharData("\n"), + Directive(`DOCTYPE ['>]`), + CharData("\n"), + Directive(`DOCTYPE []`), + CharData("\n"), + Directive(`DOCTYPE ['>]`), + CharData("\n"), + Directive(`DOCTYPE []`), + CharData("\n"), +} + +func TestNestedDirectives(t *testing.T) { + d := NewDecoder(strings.NewReader(nestedDirectivesInput)) + + for i, want := range nestedDirectivesTokens { + have, err := d.Token() + if err != nil { + t.Fatalf("token %d: unexpected error: %s", i, err) + } + if !reflect.DeepEqual(have, want) { + t.Errorf("token %d = %#v want %#v", i, have, want) + } + } +} + +func TestToken(t *testing.T) { + d := NewDecoder(strings.NewReader(testInput)) + d.Entity = testEntity + + for i, want := range cookedTokens { + have, err := d.Token() + if err != nil { + t.Fatalf("token %d: unexpected error: %s", i, err) + } + if !reflect.DeepEqual(have, want) { + t.Errorf("token %d = %#v want %#v", i, have, want) + } + } +} + +func TestSyntax(t *testing.T) { + for i := range xmlInput { + d := NewDecoder(strings.NewReader(xmlInput[i])) + var err error + for _, err = d.Token(); err == nil; _, err = d.Token() { + } + if _, ok := err.(*SyntaxError); !ok { + t.Fatalf(`xmlInput "%s": expected SyntaxError not received`, xmlInput[i]) + } + } +} + +type allScalars struct { + True1 bool + True2 bool + False1 bool + False2 bool + Int int + Int8 int8 + Int16 int16 + Int32 int32 + Int64 int64 + Uint int + Uint8 uint8 + Uint16 uint16 + Uint32 uint32 + Uint64 uint64 + Uintptr uintptr + Float32 float32 + Float64 float64 + String string + PtrString *string +} + +var all = allScalars{ + True1: true, + True2: true, + False1: false, + False2: false, + Int: 1, + Int8: -2, + Int16: 3, + Int32: -4, + Int64: 5, + Uint: 6, + Uint8: 7, + Uint16: 8, + Uint32: 9, + Uint64: 10, + Uintptr: 11, + Float32: 13.0, + Float64: 14.0, + String: "15", + PtrString: &sixteen, +} + +var sixteen = "16" + +const testScalarsInput = ` + true + 1 + false + 0 + 1 + -2 + 3 + -4 + 5 + 6 + 7 + 8 + 9 + 10 + 11 + 12.0 + 13.0 + 14.0 + 15 + 16 +` + +func TestAllScalars(t *testing.T) { + var a allScalars + err := Unmarshal([]byte(testScalarsInput), &a) + + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(a, all) { + t.Errorf("have %+v want 
%+v", a, all) + } +} + +type item struct { + Field_a string +} + +func TestIssue569(t *testing.T) { + data := `abcd` + var i item + err := Unmarshal([]byte(data), &i) + + if err != nil || i.Field_a != "abcd" { + t.Fatal("Expecting abcd") + } +} + +func TestUnquotedAttrs(t *testing.T) { + data := "" + d := NewDecoder(strings.NewReader(data)) + d.Strict = false + token, err := d.Token() + if _, ok := err.(*SyntaxError); ok { + t.Errorf("Unexpected error: %v", err) + } + if token.(StartElement).Name.Local != "tag" { + t.Errorf("Unexpected tag name: %v", token.(StartElement).Name.Local) + } + attr := token.(StartElement).Attr[0] + if attr.Value != "azAZ09:-_" { + t.Errorf("Unexpected attribute value: %v", attr.Value) + } + if attr.Name.Local != "attr" { + t.Errorf("Unexpected attribute name: %v", attr.Name.Local) + } +} + +func TestValuelessAttrs(t *testing.T) { + tests := [][3]string{ + {"

", "p", "nowrap"}, + {"

", "p", "nowrap"}, + {"", "input", "checked"}, + {"", "input", "checked"}, + } + for _, test := range tests { + d := NewDecoder(strings.NewReader(test[0])) + d.Strict = false + token, err := d.Token() + if _, ok := err.(*SyntaxError); ok { + t.Errorf("Unexpected error: %v", err) + } + if token.(StartElement).Name.Local != test[1] { + t.Errorf("Unexpected tag name: %v", token.(StartElement).Name.Local) + } + attr := token.(StartElement).Attr[0] + if attr.Value != test[2] { + t.Errorf("Unexpected attribute value: %v", attr.Value) + } + if attr.Name.Local != test[2] { + t.Errorf("Unexpected attribute name: %v", attr.Name.Local) + } + } +} + +func TestCopyTokenCharData(t *testing.T) { + data := []byte("same data") + var tok1 Token = CharData(data) + tok2 := CopyToken(tok1) + if !reflect.DeepEqual(tok1, tok2) { + t.Error("CopyToken(CharData) != CharData") + } + data[1] = 'o' + if reflect.DeepEqual(tok1, tok2) { + t.Error("CopyToken(CharData) uses same buffer.") + } +} + +func TestCopyTokenStartElement(t *testing.T) { + elt := StartElement{Name{"", "hello"}, []Attr{{Name{"", "lang"}, "en"}}} + var tok1 Token = elt + tok2 := CopyToken(tok1) + if tok1.(StartElement).Attr[0].Value != "en" { + t.Error("CopyToken overwrote Attr[0]") + } + if !reflect.DeepEqual(tok1, tok2) { + t.Error("CopyToken(StartElement) != StartElement") + } + tok1.(StartElement).Attr[0] = Attr{Name{"", "lang"}, "de"} + if reflect.DeepEqual(tok1, tok2) { + t.Error("CopyToken(CharData) uses same buffer.") + } +} + +func TestSyntaxErrorLineNum(t *testing.T) { + testInput := "

Foo

\n\n

Bar\n" + d := NewDecoder(strings.NewReader(testInput)) + var err error + for _, err = d.Token(); err == nil; _, err = d.Token() { + } + synerr, ok := err.(*SyntaxError) + if !ok { + t.Error("Expected SyntaxError.") + } + if synerr.Line != 3 { + t.Error("SyntaxError didn't have correct line number.") + } +} + +func TestTrailingRawToken(t *testing.T) { + input := ` ` + d := NewDecoder(strings.NewReader(input)) + var err error + for _, err = d.RawToken(); err == nil; _, err = d.RawToken() { + } + if err != io.EOF { + t.Fatalf("d.RawToken() = _, %v, want _, io.EOF", err) + } +} + +func TestTrailingToken(t *testing.T) { + input := ` ` + d := NewDecoder(strings.NewReader(input)) + var err error + for _, err = d.Token(); err == nil; _, err = d.Token() { + } + if err != io.EOF { + t.Fatalf("d.Token() = _, %v, want _, io.EOF", err) + } +} + +func TestEntityInsideCDATA(t *testing.T) { + input := `` + d := NewDecoder(strings.NewReader(input)) + var err error + for _, err = d.Token(); err == nil; _, err = d.Token() { + } + if err != io.EOF { + t.Fatalf("d.Token() = _, %v, want _, io.EOF", err) + } +} + +var characterTests = []struct { + in string + err string +}{ + {"\x12", "illegal character code U+0012"}, + {"\x0b", "illegal character code U+000B"}, + {"\xef\xbf\xbe", "illegal character code U+FFFE"}, + {"\r\n\x07", "illegal character code U+0007"}, + {"what's up", "expected attribute name in element"}, + {"&abc\x01;", "invalid character entity &abc (no semicolon)"}, + {"&\x01;", "invalid character entity & (no semicolon)"}, + {"&\xef\xbf\xbe;", "invalid character entity &\uFFFE;"}, + {"&hello;", "invalid character entity &hello;"}, +} + +func TestDisallowedCharacters(t *testing.T) { + + for i, tt := range characterTests { + d := NewDecoder(strings.NewReader(tt.in)) + var err error + + for err == nil { + _, err = d.Token() + } + synerr, ok := err.(*SyntaxError) + if !ok { + t.Fatalf("input %d d.Token() = _, %v, want _, *SyntaxError", i, err) + } + if synerr.Msg != tt.err { + t.Fatalf("input %d synerr.Msg wrong: want %q, got %q", i, tt.err, synerr.Msg) + } + } +} + +type procInstEncodingTest struct { + expect, got string +} + +var procInstTests = []struct { + input string + expect [2]string +}{ + {`version="1.0" encoding="utf-8"`, [2]string{"1.0", "utf-8"}}, + {`version="1.0" encoding='utf-8'`, [2]string{"1.0", "utf-8"}}, + {`version="1.0" encoding='utf-8' `, [2]string{"1.0", "utf-8"}}, + {`version="1.0" encoding=utf-8`, [2]string{"1.0", ""}}, + {`encoding="FOO" `, [2]string{"", "FOO"}}, +} + +func TestProcInstEncoding(t *testing.T) { + for _, test := range procInstTests { + if got := procInst("version", test.input); got != test.expect[0] { + t.Errorf("procInst(version, %q) = %q; want %q", test.input, got, test.expect[0]) + } + if got := procInst("encoding", test.input); got != test.expect[1] { + t.Errorf("procInst(encoding, %q) = %q; want %q", test.input, got, test.expect[1]) + } + } +} + +// Ensure that directives with comments include the complete +// text of any nested directives. 
+ +var directivesWithCommentsInput = ` +]> +]> + --> --> []> +` + +var directivesWithCommentsTokens = []Token{ + CharData("\n"), + Directive(`DOCTYPE []`), + CharData("\n"), + Directive(`DOCTYPE []`), + CharData("\n"), + Directive(`DOCTYPE []`), + CharData("\n"), +} + +func TestDirectivesWithComments(t *testing.T) { + d := NewDecoder(strings.NewReader(directivesWithCommentsInput)) + + for i, want := range directivesWithCommentsTokens { + have, err := d.Token() + if err != nil { + t.Fatalf("token %d: unexpected error: %s", i, err) + } + if !reflect.DeepEqual(have, want) { + t.Errorf("token %d = %#v want %#v", i, have, want) + } + } +} + +// Writer whose Write method always returns an error. +type errWriter struct{} + +func (errWriter) Write(p []byte) (n int, err error) { return 0, fmt.Errorf("unwritable") } + +func TestEscapeTextIOErrors(t *testing.T) { + expectErr := "unwritable" + err := EscapeText(errWriter{}, []byte{'A'}) + + if err == nil || err.Error() != expectErr { + t.Errorf("have %v, want %v", err, expectErr) + } +} + +func TestEscapeTextInvalidChar(t *testing.T) { + input := []byte("A \x00 terminated string.") + expected := "A \uFFFD terminated string." + + buff := new(bytes.Buffer) + if err := EscapeText(buff, input); err != nil { + t.Fatalf("have %v, want nil", err) + } + text := buff.String() + + if text != expected { + t.Errorf("have %v, want %v", text, expected) + } +} + +func TestIssue5880(t *testing.T) { + type T []byte + data, err := Marshal(T{192, 168, 0, 1}) + if err != nil { + t.Errorf("Marshal error: %v", err) + } + if !utf8.Valid(data) { + t.Errorf("Marshal generated invalid UTF-8: %x", data) + } +} diff --git a/server/webdav/litmus_test_server.go b/server/webdav/litmus_test_server.go new file mode 100644 index 0000000000000000000000000000000000000000..6334d7e23357d1f6a78d4634f0a50596ed0f5670 --- /dev/null +++ b/server/webdav/litmus_test_server.go @@ -0,0 +1,95 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build ignore +// +build ignore + +/* +This program is a server for the WebDAV 'litmus' compliance test at +http://www.webdav.org/neon/litmus/ +To run the test: + +go run litmus_test_server.go + +and separately, from the downloaded litmus-xxx directory: + +make URL=http://localhost:9999/ check +*/ +package main + +import ( + "flag" + "fmt" + "log" + "net/http" + "net/url" + + "golang.org/x/net/webdav" +) + +var port = flag.Int("port", 9999, "server port") + +func main() { + flag.Parse() + log.SetFlags(0) + h := &webdav.Handler{ + FileSystem: webdav.NewMemFS(), + LockSystem: webdav.NewMemLS(), + Logger: func(r *http.Request, err error) { + litmus := r.Header.Get("X-Litmus") + if len(litmus) > 19 { + litmus = litmus[:16] + "..." + } + + switch r.Method { + case "COPY", "MOVE": + dst := "" + if u, err := url.Parse(r.Header.Get("Destination")); err == nil { + dst = u.Path + } + o := r.Header.Get("Overwrite") + log.Printf("%-20s%-10s%-30s%-30so=%-2s%v", litmus, r.Method, r.URL.Path, dst, o, err) + default: + log.Printf("%-20s%-10s%-30s%v", litmus, r.Method, r.URL.Path, err) + } + }, + } + + // The next line would normally be: + // http.Handle("/", h) + // but we wrap that HTTP handler h to cater for a special case. + // + // The propfind_invalid2 litmus test case expects an empty namespace prefix + // declaration to be an error. The FAQ in the webdav litmus test says: + // + // "What does the "propfind_invalid2" test check for?... 
+ // + // If a request was sent with an XML body which included an empty namespace + // prefix declaration (xmlns:ns1=""), then the server must reject that with + // a "400 Bad Request" response, as it is invalid according to the XML + // Namespace specification." + // + // On the other hand, the Go standard library's encoding/xml package + // accepts an empty xmlns namespace, as per the discussion at + // https://github.com/golang/go/issues/8068 + // + // Empty namespaces seem disallowed in the second (2006) edition of the XML + // standard, but allowed in a later edition. The grammar differs between + // http://www.w3.org/TR/2006/REC-xml-names-20060816/#ns-decl and + // http://www.w3.org/TR/REC-xml-names/#dt-prefix + // + // Thus, we assume that the propfind_invalid2 test is obsolete, and + // hard-code the 400 Bad Request response that the test expects. + http.Handle("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("X-Litmus") == "props: 3 (propfind_invalid2)" { + http.Error(w, "400 Bad Request", http.StatusBadRequest) + return + } + h.ServeHTTP(w, r) + })) + + addr := fmt.Sprintf(":%d", *port) + log.Printf("Serving %v", addr) + log.Fatal(http.ListenAndServe(addr, nil)) +} diff --git a/server/webdav/lock.go b/server/webdav/lock.go new file mode 100644 index 0000000000000000000000000000000000000000..344ac5ceaf140b79ef3af6ecfca0139d4227f0b1 --- /dev/null +++ b/server/webdav/lock.go @@ -0,0 +1,445 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package webdav + +import ( + "container/heap" + "errors" + "strconv" + "strings" + "sync" + "time" +) + +var ( + // ErrConfirmationFailed is returned by a LockSystem's Confirm method. + ErrConfirmationFailed = errors.New("webdav: confirmation failed") + // ErrForbidden is returned by a LockSystem's Unlock method. + ErrForbidden = errors.New("webdav: forbidden") + // ErrLocked is returned by a LockSystem's Create, Refresh and Unlock methods. + ErrLocked = errors.New("webdav: locked") + // ErrNoSuchLock is returned by a LockSystem's Refresh and Unlock methods. + ErrNoSuchLock = errors.New("webdav: no such lock") +) + +// Condition can match a WebDAV resource, based on a token or ETag. +// Exactly one of Token and ETag should be non-empty. +type Condition struct { + Not bool + Token string + ETag string +} + +// LockSystem manages access to a collection of named resources. The elements +// in a lock name are separated by slash ('/', U+002F) characters, regardless +// of host operating system convention. +type LockSystem interface { + // Confirm confirms that the caller can claim all of the locks specified by + // the given conditions, and that holding the union of all of those locks + // gives exclusive access to all of the named resources. Up to two resources + // can be named. Empty names are ignored. + // + // Exactly one of release and err will be non-nil. If release is non-nil, + // all of the requested locks are held until release is called. Calling + // release does not unlock the lock, in the WebDAV UNLOCK sense, but once + // Confirm has confirmed that a lock claim is valid, that lock cannot be + // Confirmed again until it has been released. + // + // If Confirm returns ErrConfirmationFailed then the Handler will continue + // to try any other set of locks presented (a WebDAV HTTP request can + // present more than one set of locks). 
If it returns any other non-nil + // error, the Handler will write a "500 Internal Server Error" HTTP status. + Confirm(now time.Time, name0, name1 string, conditions ...Condition) (release func(), err error) + + // Create creates a lock with the given depth, duration, owner and root + // (name). The depth will either be negative (meaning infinite) or zero. + // + // If Create returns ErrLocked then the Handler will write a "423 Locked" + // HTTP status. If it returns any other non-nil error, the Handler will + // write a "500 Internal Server Error" HTTP status. + // + // See http://www.webdav.org/specs/rfc4918.html#rfc.section.9.10.6 for + // when to use each error. + // + // The token returned identifies the created lock. It should be an absolute + // URI as defined by RFC 3986, Section 4.3. In particular, it should not + // contain whitespace. + Create(now time.Time, details LockDetails) (token string, err error) + + // Refresh refreshes the lock with the given token. + // + // If Refresh returns ErrLocked then the Handler will write a "423 Locked" + // HTTP Status. If Refresh returns ErrNoSuchLock then the Handler will write + // a "412 Precondition Failed" HTTP Status. If it returns any other non-nil + // error, the Handler will write a "500 Internal Server Error" HTTP status. + // + // See http://www.webdav.org/specs/rfc4918.html#rfc.section.9.10.6 for + // when to use each error. + Refresh(now time.Time, token string, duration time.Duration) (LockDetails, error) + + // Unlock unlocks the lock with the given token. + // + // If Unlock returns ErrForbidden then the Handler will write a "403 + // Forbidden" HTTP Status. If Unlock returns ErrLocked then the Handler + // will write a "423 Locked" HTTP status. If Unlock returns ErrNoSuchLock + // then the Handler will write a "409 Conflict" HTTP Status. If it returns + // any other non-nil error, the Handler will write a "500 Internal Server + // Error" HTTP status. + // + // See http://www.webdav.org/specs/rfc4918.html#rfc.section.9.11.1 for + // when to use each error. + Unlock(now time.Time, token string) error +} + +// LockDetails are a lock's metadata. +type LockDetails struct { + // Root is the root resource name being locked. For a zero-depth lock, the + // root is the only resource being locked. + Root string + // Duration is the lock timeout. A negative duration means infinite. + Duration time.Duration + // OwnerXML is the verbatim XML given in a LOCK HTTP request. + // + // TODO: does the "verbatim" nature play well with XML namespaces? + // Does the OwnerXML field need to have more structure? See + // https://codereview.appspot.com/175140043/#msg2 + OwnerXML string + // ZeroDepth is whether the lock has zero depth. If it does not have zero + // depth, it has infinite depth. + ZeroDepth bool +} + +// NewMemLS returns a new in-memory LockSystem. +func NewMemLS() LockSystem { + return &memLS{ + byName: make(map[string]*memLSNode), + byToken: make(map[string]*memLSNode), + gen: uint64(time.Now().Unix()), + } +} + +type memLS struct { + mu sync.Mutex + byName map[string]*memLSNode + byToken map[string]*memLSNode + gen uint64 + // byExpiry only contains those nodes whose LockDetails have a finite + // Duration and are yet to expire. 
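+	// It is a min-heap ordered by expiry time (see the byExpiry methods
+	// below), so collectExpiredNodes can pop expired locks from the front.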
+ byExpiry byExpiry +} + +func (m *memLS) nextToken() string { + m.gen++ + return strconv.FormatUint(m.gen, 10) +} + +func (m *memLS) collectExpiredNodes(now time.Time) { + for len(m.byExpiry) > 0 { + if now.Before(m.byExpiry[0].expiry) { + break + } + m.remove(m.byExpiry[0]) + } +} + +func (m *memLS) Confirm(now time.Time, name0, name1 string, conditions ...Condition) (func(), error) { + m.mu.Lock() + defer m.mu.Unlock() + m.collectExpiredNodes(now) + + var n0, n1 *memLSNode + if name0 != "" { + if n0 = m.lookup(slashClean(name0), conditions...); n0 == nil { + return nil, ErrConfirmationFailed + } + } + if name1 != "" { + if n1 = m.lookup(slashClean(name1), conditions...); n1 == nil { + return nil, ErrConfirmationFailed + } + } + + // Don't hold the same node twice. + if n1 == n0 { + n1 = nil + } + + if n0 != nil { + m.hold(n0) + } + if n1 != nil { + m.hold(n1) + } + return func() { + m.mu.Lock() + defer m.mu.Unlock() + if n1 != nil { + m.unhold(n1) + } + if n0 != nil { + m.unhold(n0) + } + }, nil +} + +// lookup returns the node n that locks the named resource, provided that n +// matches at least one of the given conditions and that lock isn't held by +// another party. Otherwise, it returns nil. +// +// n may be a parent of the named resource, if n is an infinite depth lock. +func (m *memLS) lookup(name string, conditions ...Condition) (n *memLSNode) { + // TODO: support Condition.Not and Condition.ETag. + for _, c := range conditions { + n = m.byToken[c.Token] + if n == nil || n.held { + continue + } + if name == n.details.Root { + return n + } + if n.details.ZeroDepth { + continue + } + if n.details.Root == "/" || strings.HasPrefix(name, n.details.Root+"/") { + return n + } + } + return nil +} + +func (m *memLS) hold(n *memLSNode) { + if n.held { + panic("webdav: memLS inconsistent held state") + } + n.held = true + if n.details.Duration >= 0 && n.byExpiryIndex >= 0 { + heap.Remove(&m.byExpiry, n.byExpiryIndex) + } +} + +func (m *memLS) unhold(n *memLSNode) { + if !n.held { + panic("webdav: memLS inconsistent held state") + } + n.held = false + if n.details.Duration >= 0 { + heap.Push(&m.byExpiry, n) + } +} + +func (m *memLS) Create(now time.Time, details LockDetails) (string, error) { + m.mu.Lock() + defer m.mu.Unlock() + m.collectExpiredNodes(now) + details.Root = slashClean(details.Root) + + if !m.canCreate(details.Root, details.ZeroDepth) { + return "", ErrLocked + } + n := m.create(details.Root) + n.token = m.nextToken() + m.byToken[n.token] = n + n.details = details + if n.details.Duration >= 0 { + n.expiry = now.Add(n.details.Duration) + heap.Push(&m.byExpiry, n) + } + return n.token, nil +} + +func (m *memLS) Refresh(now time.Time, token string, duration time.Duration) (LockDetails, error) { + m.mu.Lock() + defer m.mu.Unlock() + m.collectExpiredNodes(now) + + n := m.byToken[token] + if n == nil { + return LockDetails{}, ErrNoSuchLock + } + if n.held { + return LockDetails{}, ErrLocked + } + if n.byExpiryIndex >= 0 { + heap.Remove(&m.byExpiry, n.byExpiryIndex) + } + n.details.Duration = duration + if n.details.Duration >= 0 { + n.expiry = now.Add(n.details.Duration) + heap.Push(&m.byExpiry, n) + } + return n.details, nil +} + +func (m *memLS) Unlock(now time.Time, token string) error { + m.mu.Lock() + defer m.mu.Unlock() + m.collectExpiredNodes(now) + + n := m.byToken[token] + if n == nil { + return ErrNoSuchLock + } + if n.held { + return ErrLocked + } + m.remove(n) + return nil +} + +func (m *memLS) canCreate(name string, zeroDepth bool) bool { + return walkToRoot(name, 
func(name0 string, first bool) bool { + n := m.byName[name0] + if n == nil { + return true + } + if first { + if n.token != "" { + // The target node is already locked. + return false + } + if !zeroDepth { + // The requested lock depth is infinite, and the fact that n exists + // (n != nil) means that a descendent of the target node is locked. + return false + } + } else if n.token != "" && !n.details.ZeroDepth { + // An ancestor of the target node is locked with infinite depth. + return false + } + return true + }) +} + +func (m *memLS) create(name string) (ret *memLSNode) { + walkToRoot(name, func(name0 string, first bool) bool { + n := m.byName[name0] + if n == nil { + n = &memLSNode{ + details: LockDetails{ + Root: name0, + }, + byExpiryIndex: -1, + } + m.byName[name0] = n + } + n.refCount++ + if first { + ret = n + } + return true + }) + return ret +} + +func (m *memLS) remove(n *memLSNode) { + delete(m.byToken, n.token) + n.token = "" + walkToRoot(n.details.Root, func(name0 string, first bool) bool { + x := m.byName[name0] + x.refCount-- + if x.refCount == 0 { + delete(m.byName, name0) + } + return true + }) + if n.byExpiryIndex >= 0 { + heap.Remove(&m.byExpiry, n.byExpiryIndex) + } +} + +func walkToRoot(name string, f func(name0 string, first bool) bool) bool { + for first := true; ; first = false { + if !f(name, first) { + return false + } + if name == "/" { + break + } + name = name[:strings.LastIndex(name, "/")] + if name == "" { + name = "/" + } + } + return true +} + +type memLSNode struct { + // details are the lock metadata. Even if this node's name is not explicitly locked, + // details.Root will still equal the node's name. + details LockDetails + // token is the unique identifier for this node's lock. An empty token means that + // this node is not explicitly locked. + token string + // refCount is the number of self-or-descendent nodes that are explicitly locked. + refCount int + // expiry is when this node's lock expires. + expiry time.Time + // byExpiryIndex is the index of this node in memLS.byExpiry. It is -1 + // if this node does not expire, or has expired. + byExpiryIndex int + // held is whether this node's lock is actively held by a Confirm call. + held bool +} + +type byExpiry []*memLSNode + +func (b *byExpiry) Len() int { + return len(*b) +} + +func (b *byExpiry) Less(i, j int) bool { + return (*b)[i].expiry.Before((*b)[j].expiry) +} + +func (b *byExpiry) Swap(i, j int) { + (*b)[i], (*b)[j] = (*b)[j], (*b)[i] + (*b)[i].byExpiryIndex = i + (*b)[j].byExpiryIndex = j +} + +func (b *byExpiry) Push(x interface{}) { + n := x.(*memLSNode) + n.byExpiryIndex = len(*b) + *b = append(*b, n) +} + +func (b *byExpiry) Pop() interface{} { + i := len(*b) - 1 + n := (*b)[i] + (*b)[i] = nil + n.byExpiryIndex = -1 + *b = (*b)[:i] + return n +} + +const infiniteTimeout = -1 + +// parseTimeout parses the Timeout HTTP header, as per section 10.7. If s is +// empty, an infiniteTimeout is returned. 
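+//
+// For example:
+//
+//	parseTimeout("")                    // infiniteTimeout, nil
+//	parseTimeout("Infinite")            // infiniteTimeout, nil
+//	parseTimeout("Second-3600")         // time.Hour, nil
+//	parseTimeout("Second-5, Second-10") // 5*time.Second, nil (only the first entry is used)
+//	parseTimeout("3600")                // 0, errInvalidTimeout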
+func parseTimeout(s string) (time.Duration, error) { + if s == "" { + return infiniteTimeout, nil + } + if i := strings.IndexByte(s, ','); i >= 0 { + s = s[:i] + } + s = strings.TrimSpace(s) + if s == "Infinite" { + return infiniteTimeout, nil + } + const pre = "Second-" + if !strings.HasPrefix(s, pre) { + return 0, errInvalidTimeout + } + s = s[len(pre):] + if s == "" || s[0] < '0' || '9' < s[0] { + return 0, errInvalidTimeout + } + n, err := strconv.ParseInt(s, 10, 64) + if err != nil || 1<<32-1 < n { + return 0, errInvalidTimeout + } + return time.Duration(n) * time.Second, nil +} diff --git a/server/webdav/lock_test.go b/server/webdav/lock_test.go new file mode 100644 index 0000000000000000000000000000000000000000..e7fe97061b66c237e716619ca66d487388a2b91c --- /dev/null +++ b/server/webdav/lock_test.go @@ -0,0 +1,735 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package webdav + +import ( + "fmt" + "math/rand" + "path" + "reflect" + "sort" + "strconv" + "strings" + "testing" + "time" +) + +func TestWalkToRoot(t *testing.T) { + testCases := []struct { + name string + want []string + }{{ + "/a/b/c/d", + []string{ + "/a/b/c/d", + "/a/b/c", + "/a/b", + "/a", + "/", + }, + }, { + "/a", + []string{ + "/a", + "/", + }, + }, { + "/", + []string{ + "/", + }, + }} + + for _, tc := range testCases { + var got []string + if !walkToRoot(tc.name, func(name0 string, first bool) bool { + if first != (len(got) == 0) { + t.Errorf("name=%q: first=%t but len(got)==%d", tc.name, first, len(got)) + return false + } + got = append(got, name0) + return true + }) { + continue + } + if !reflect.DeepEqual(got, tc.want) { + t.Errorf("name=%q:\ngot %q\nwant %q", tc.name, got, tc.want) + } + } +} + +var lockTestDurations = []time.Duration{ + infiniteTimeout, // infiniteTimeout means to never expire. + 0, // A zero duration means to expire immediately. + 100 * time.Hour, // A very large duration will not expire in these tests. +} + +// lockTestNames are the names of a set of mutually compatible locks. For each +// name fragment: +// - _ means no explicit lock. +// - i means an infinite-depth lock, +// - z means a zero-depth lock, +var lockTestNames = []string{ + "/_/_/_/_/z", + "/_/_/i", + "/_/z", + "/_/z/i", + "/_/z/z", + "/_/z/_/i", + "/_/z/_/z", + "/i", + "/z", + "/z/_/i", + "/z/_/z", +} + +func lockTestZeroDepth(name string) bool { + switch name[len(name)-1] { + case 'i': + return false + case 'z': + return true + } + panic(fmt.Sprintf("lock name %q did not end with 'i' or 'z'", name)) +} + +func TestMemLSCanCreate(t *testing.T) { + now := time.Unix(0, 0) + m := NewMemLS().(*memLS) + + for _, name := range lockTestNames { + _, err := m.Create(now, LockDetails{ + Root: name, + Duration: infiniteTimeout, + ZeroDepth: lockTestZeroDepth(name), + }) + if err != nil { + t.Fatalf("creating lock for %q: %v", name, err) + } + } + + wantCanCreate := func(name string, zeroDepth bool) bool { + for _, n := range lockTestNames { + switch { + case n == name: + // An existing lock has the same name as the proposed lock. + return false + case strings.HasPrefix(n, name): + // An existing lock would be a child of the proposed lock, + // which conflicts if the proposed lock has infinite depth. + if !zeroDepth { + return false + } + case strings.HasPrefix(name, n): + // An existing lock would be an ancestor of the proposed lock, + // which conflicts if the ancestor has infinite depth. 
+ if n[len(n)-1] == 'i' { + return false + } + } + } + return true + } + + var check func(int, string) + check = func(recursion int, name string) { + for _, zeroDepth := range []bool{false, true} { + got := m.canCreate(name, zeroDepth) + want := wantCanCreate(name, zeroDepth) + if got != want { + t.Errorf("canCreate name=%q zeroDepth=%t: got %t, want %t", name, zeroDepth, got, want) + } + } + if recursion == 6 { + return + } + if name != "/" { + name += "/" + } + for _, c := range "_iz" { + check(recursion+1, name+string(c)) + } + } + check(0, "/") +} + +func TestMemLSLookup(t *testing.T) { + now := time.Unix(0, 0) + m := NewMemLS().(*memLS) + + badToken := m.nextToken() + t.Logf("badToken=%q", badToken) + + for _, name := range lockTestNames { + token, err := m.Create(now, LockDetails{ + Root: name, + Duration: infiniteTimeout, + ZeroDepth: lockTestZeroDepth(name), + }) + if err != nil { + t.Fatalf("creating lock for %q: %v", name, err) + } + t.Logf("%-15q -> node=%p token=%q", name, m.byName[name], token) + } + + baseNames := append([]string{"/a", "/b/c"}, lockTestNames...) + for _, baseName := range baseNames { + for _, suffix := range []string{"", "/0", "/1/2/3"} { + name := baseName + suffix + + goodToken := "" + base := m.byName[baseName] + if base != nil && (suffix == "" || !lockTestZeroDepth(baseName)) { + goodToken = base.token + } + + for _, token := range []string{badToken, goodToken} { + if token == "" { + continue + } + + got := m.lookup(name, Condition{Token: token}) + want := base + if token == badToken { + want = nil + } + if got != want { + t.Errorf("name=%-20qtoken=%q (bad=%t): got %p, want %p", + name, token, token == badToken, got, want) + } + } + } + } +} + +func TestMemLSConfirm(t *testing.T) { + now := time.Unix(0, 0) + m := NewMemLS().(*memLS) + alice, err := m.Create(now, LockDetails{ + Root: "/alice", + Duration: infiniteTimeout, + ZeroDepth: false, + }) + if err != nil { + t.Fatalf("Create: %v", err) + } + + tweedle, err := m.Create(now, LockDetails{ + Root: "/tweedle", + Duration: infiniteTimeout, + ZeroDepth: false, + }) + if err != nil { + t.Fatalf("Create: %v", err) + } + if err := m.consistent(); err != nil { + t.Fatalf("Create: inconsistent state: %v", err) + } + + // Test a mismatch between name and condition. + _, err = m.Confirm(now, "/tweedle/dee", "", Condition{Token: alice}) + if err != ErrConfirmationFailed { + t.Fatalf("Confirm (mismatch): got %v, want ErrConfirmationFailed", err) + } + if err := m.consistent(); err != nil { + t.Fatalf("Confirm (mismatch): inconsistent state: %v", err) + } + + // Test two names (that fall under the same lock) in the one Confirm call. + release, err := m.Confirm(now, "/tweedle/dee", "/tweedle/dum", Condition{Token: tweedle}) + if err != nil { + t.Fatalf("Confirm (twins): %v", err) + } + if err := m.consistent(); err != nil { + t.Fatalf("Confirm (twins): inconsistent state: %v", err) + } + release() + if err := m.consistent(); err != nil { + t.Fatalf("release (twins): inconsistent state: %v", err) + } + + // Test the same two names in overlapping Confirm / release calls. 
+ releaseDee, err := m.Confirm(now, "/tweedle/dee", "", Condition{Token: tweedle}) + if err != nil { + t.Fatalf("Confirm (sequence #0): %v", err) + } + if err := m.consistent(); err != nil { + t.Fatalf("Confirm (sequence #0): inconsistent state: %v", err) + } + + _, err = m.Confirm(now, "/tweedle/dum", "", Condition{Token: tweedle}) + if err != ErrConfirmationFailed { + t.Fatalf("Confirm (sequence #1): got %v, want ErrConfirmationFailed", err) + } + if err := m.consistent(); err != nil { + t.Fatalf("Confirm (sequence #1): inconsistent state: %v", err) + } + + releaseDee() + if err := m.consistent(); err != nil { + t.Fatalf("release (sequence #2): inconsistent state: %v", err) + } + + releaseDum, err := m.Confirm(now, "/tweedle/dum", "", Condition{Token: tweedle}) + if err != nil { + t.Fatalf("Confirm (sequence #3): %v", err) + } + if err := m.consistent(); err != nil { + t.Fatalf("Confirm (sequence #3): inconsistent state: %v", err) + } + + // Test that you can't unlock a held lock. + err = m.Unlock(now, tweedle) + if err != ErrLocked { + t.Fatalf("Unlock (sequence #4): got %v, want ErrLocked", err) + } + + releaseDum() + if err := m.consistent(); err != nil { + t.Fatalf("release (sequence #5): inconsistent state: %v", err) + } + + err = m.Unlock(now, tweedle) + if err != nil { + t.Fatalf("Unlock (sequence #6): %v", err) + } + if err := m.consistent(); err != nil { + t.Fatalf("Unlock (sequence #6): inconsistent state: %v", err) + } +} + +func TestMemLSNonCanonicalRoot(t *testing.T) { + now := time.Unix(0, 0) + m := NewMemLS().(*memLS) + token, err := m.Create(now, LockDetails{ + Root: "/foo/./bar//", + Duration: 1 * time.Second, + }) + if err != nil { + t.Fatalf("Create: %v", err) + } + if err := m.consistent(); err != nil { + t.Fatalf("Create: inconsistent state: %v", err) + } + if err := m.Unlock(now, token); err != nil { + t.Fatalf("Unlock: %v", err) + } + if err := m.consistent(); err != nil { + t.Fatalf("Unlock: inconsistent state: %v", err) + } +} + +func TestMemLSExpiry(t *testing.T) { + m := NewMemLS().(*memLS) + testCases := []string{ + "setNow 0", + "create /a.5", + "want /a.5", + "create /c.6", + "want /a.5 /c.6", + "create /a/b.7", + "want /a.5 /a/b.7 /c.6", + "setNow 4", + "want /a.5 /a/b.7 /c.6", + "setNow 5", + "want /a/b.7 /c.6", + "setNow 6", + "want /a/b.7", + "setNow 7", + "want ", + "setNow 8", + "want ", + "create /a.12", + "create /b.13", + "create /c.15", + "create /a/d.16", + "want /a.12 /a/d.16 /b.13 /c.15", + "refresh /a.14", + "want /a.14 /a/d.16 /b.13 /c.15", + "setNow 12", + "want /a.14 /a/d.16 /b.13 /c.15", + "setNow 13", + "want /a.14 /a/d.16 /c.15", + "setNow 14", + "want /a/d.16 /c.15", + "refresh /a/d.20", + "refresh /c.20", + "want /a/d.20 /c.20", + "setNow 20", + "want ", + } + + tokens := map[string]string{} + zTime := time.Unix(0, 0) + now := zTime + for i, tc := range testCases { + j := strings.IndexByte(tc, ' ') + if j < 0 { + t.Fatalf("test case #%d %q: invalid command", i, tc) + } + op, arg := tc[:j], tc[j+1:] + switch op { + default: + t.Fatalf("test case #%d %q: invalid operation %q", i, tc, op) + + case "create", "refresh": + parts := strings.Split(arg, ".") + if len(parts) != 2 { + t.Fatalf("test case #%d %q: invalid create", i, tc) + } + root := parts[0] + d, err := strconv.Atoi(parts[1]) + if err != nil { + t.Fatalf("test case #%d %q: invalid duration", i, tc) + } + dur := time.Unix(0, 0).Add(time.Duration(d) * time.Second).Sub(now) + + switch op { + case "create": + token, err := m.Create(now, LockDetails{ + Root: root, + Duration: dur, + 
ZeroDepth: true, + }) + if err != nil { + t.Fatalf("test case #%d %q: Create: %v", i, tc, err) + } + tokens[root] = token + + case "refresh": + token := tokens[root] + if token == "" { + t.Fatalf("test case #%d %q: no token for %q", i, tc, root) + } + got, err := m.Refresh(now, token, dur) + if err != nil { + t.Fatalf("test case #%d %q: Refresh: %v", i, tc, err) + } + want := LockDetails{ + Root: root, + Duration: dur, + ZeroDepth: true, + } + if got != want { + t.Fatalf("test case #%d %q:\ngot %v\nwant %v", i, tc, got, want) + } + } + + case "setNow": + d, err := strconv.Atoi(arg) + if err != nil { + t.Fatalf("test case #%d %q: invalid duration", i, tc) + } + now = time.Unix(0, 0).Add(time.Duration(d) * time.Second) + + case "want": + m.mu.Lock() + m.collectExpiredNodes(now) + got := make([]string, 0, len(m.byToken)) + for _, n := range m.byToken { + got = append(got, fmt.Sprintf("%s.%d", + n.details.Root, n.expiry.Sub(zTime)/time.Second)) + } + m.mu.Unlock() + sort.Strings(got) + want := []string{} + if arg != "" { + want = strings.Split(arg, " ") + } + if !reflect.DeepEqual(got, want) { + t.Fatalf("test case #%d %q:\ngot %q\nwant %q", i, tc, got, want) + } + } + + if err := m.consistent(); err != nil { + t.Fatalf("test case #%d %q: inconsistent state: %v", i, tc, err) + } + } +} + +func TestMemLS(t *testing.T) { + now := time.Unix(0, 0) + m := NewMemLS().(*memLS) + rng := rand.New(rand.NewSource(0)) + tokens := map[string]string{} + nConfirm, nCreate, nRefresh, nUnlock := 0, 0, 0, 0 + const N = 2000 + + for i := 0; i < N; i++ { + name := lockTestNames[rng.Intn(len(lockTestNames))] + duration := lockTestDurations[rng.Intn(len(lockTestDurations))] + confirmed, unlocked := false, false + + // If the name was already locked, we randomly confirm/release, refresh + // or unlock it. Otherwise, we create a lock. + token := tokens[name] + if token != "" { + switch rng.Intn(3) { + case 0: + confirmed = true + nConfirm++ + release, err := m.Confirm(now, name, "", Condition{Token: token}) + if err != nil { + t.Fatalf("iteration #%d: Confirm %q: %v", i, name, err) + } + if err := m.consistent(); err != nil { + t.Fatalf("iteration #%d: inconsistent state: %v", i, err) + } + release() + + case 1: + nRefresh++ + if _, err := m.Refresh(now, token, duration); err != nil { + t.Fatalf("iteration #%d: Refresh %q: %v", i, name, err) + } + + case 2: + unlocked = true + nUnlock++ + if err := m.Unlock(now, token); err != nil { + t.Fatalf("iteration #%d: Unlock %q: %v", i, name, err) + } + } + + } else { + nCreate++ + var err error + token, err = m.Create(now, LockDetails{ + Root: name, + Duration: duration, + ZeroDepth: lockTestZeroDepth(name), + }) + if err != nil { + t.Fatalf("iteration #%d: Create %q: %v", i, name, err) + } + } + + if !confirmed { + if duration == 0 || unlocked { + // A zero-duration lock should expire immediately and is + // effectively equivalent to being unlocked. 
+ tokens[name] = "" + } else { + tokens[name] = token + } + } + + if err := m.consistent(); err != nil { + t.Fatalf("iteration #%d: inconsistent state: %v", i, err) + } + } + + if nConfirm < N/10 { + t.Fatalf("too few Confirm calls: got %d, want >= %d", nConfirm, N/10) + } + if nCreate < N/10 { + t.Fatalf("too few Create calls: got %d, want >= %d", nCreate, N/10) + } + if nRefresh < N/10 { + t.Fatalf("too few Refresh calls: got %d, want >= %d", nRefresh, N/10) + } + if nUnlock < N/10 { + t.Fatalf("too few Unlock calls: got %d, want >= %d", nUnlock, N/10) + } +} + +func (m *memLS) consistent() error { + m.mu.Lock() + defer m.mu.Unlock() + + // If m.byName is non-empty, then it must contain an entry for the root "/", + // and its refCount should equal the number of locked nodes. + if len(m.byName) > 0 { + n := m.byName["/"] + if n == nil { + return fmt.Errorf(`non-empty m.byName does not contain the root "/"`) + } + if n.refCount != len(m.byToken) { + return fmt.Errorf("root node refCount=%d, differs from len(m.byToken)=%d", n.refCount, len(m.byToken)) + } + } + + for name, n := range m.byName { + // The map keys should be consistent with the node's copy of the key. + if n.details.Root != name { + return fmt.Errorf("node name %q != byName map key %q", n.details.Root, name) + } + + // A name must be clean, and start with a "/". + if len(name) == 0 || name[0] != '/' { + return fmt.Errorf(`node name %q does not start with "/"`, name) + } + if name != path.Clean(name) { + return fmt.Errorf(`node name %q is not clean`, name) + } + + // A node's refCount should be positive. + if n.refCount <= 0 { + return fmt.Errorf("non-positive refCount for node at name %q", name) + } + + // A node's refCount should be the number of self-or-descendents that + // are locked (i.e. have a non-empty token). + var list []string + for name0, n0 := range m.byName { + // All of lockTestNames' name fragments are one byte long: '_', 'i' or 'z', + // so strings.HasPrefix is equivalent to self-or-descendent name match. + // We don't have to worry about "/foo/bar" being a false positive match + // for "/foo/b". + if strings.HasPrefix(name0, name) && n0.token != "" { + list = append(list, name0) + } + } + if n.refCount != len(list) { + sort.Strings(list) + return fmt.Errorf("node at name %q has refCount %d but locked self-or-descendents are %q (len=%d)", + name, n.refCount, list, len(list)) + } + + // A node n is in m.byToken if it has a non-empty token. + if n.token != "" { + if _, ok := m.byToken[n.token]; !ok { + return fmt.Errorf("node at name %q has token %q but not in m.byToken", name, n.token) + } + } + + // A node n is in m.byExpiry if it has a non-negative byExpiryIndex. + if n.byExpiryIndex >= 0 { + if n.byExpiryIndex >= len(m.byExpiry) { + return fmt.Errorf("node at name %q has byExpiryIndex %d but m.byExpiry has length %d", name, n.byExpiryIndex, len(m.byExpiry)) + } + if n != m.byExpiry[n.byExpiryIndex] { + return fmt.Errorf("node at name %q has byExpiryIndex %d but that indexes a different node", name, n.byExpiryIndex) + } + } + } + + for token, n := range m.byToken { + // The map keys should be consistent with the node's copy of the key. + if n.token != token { + return fmt.Errorf("node token %q != byToken map key %q", n.token, token) + } + + // Every node in m.byToken is in m.byName. 
+ if _, ok := m.byName[n.details.Root]; !ok { + return fmt.Errorf("node at name %q in m.byToken but not in m.byName", n.details.Root) + } + } + + for i, n := range m.byExpiry { + // The slice indices should be consistent with the node's copy of the index. + if n.byExpiryIndex != i { + return fmt.Errorf("node byExpiryIndex %d != byExpiry slice index %d", n.byExpiryIndex, i) + } + + // Every node in m.byExpiry is in m.byName. + if _, ok := m.byName[n.details.Root]; !ok { + return fmt.Errorf("node at name %q in m.byExpiry but not in m.byName", n.details.Root) + } + + // No node in m.byExpiry should be held. + if n.held { + return fmt.Errorf("node at name %q in m.byExpiry is held", n.details.Root) + } + } + return nil +} + +func TestParseTimeout(t *testing.T) { + testCases := []struct { + s string + want time.Duration + wantErr error + }{{ + "", + infiniteTimeout, + nil, + }, { + "Infinite", + infiniteTimeout, + nil, + }, { + "Infinitesimal", + 0, + errInvalidTimeout, + }, { + "infinite", + 0, + errInvalidTimeout, + }, { + "Second-0", + 0 * time.Second, + nil, + }, { + "Second-123", + 123 * time.Second, + nil, + }, { + " Second-456 ", + 456 * time.Second, + nil, + }, { + "Second-4100000000", + 4100000000 * time.Second, + nil, + }, { + "junk", + 0, + errInvalidTimeout, + }, { + "Second-", + 0, + errInvalidTimeout, + }, { + "Second--1", + 0, + errInvalidTimeout, + }, { + "Second--123", + 0, + errInvalidTimeout, + }, { + "Second-+123", + 0, + errInvalidTimeout, + }, { + "Second-0x123", + 0, + errInvalidTimeout, + }, { + "second-123", + 0, + errInvalidTimeout, + }, { + "Second-4294967295", + 4294967295 * time.Second, + nil, + }, { + // Section 10.7 says that "The timeout value for TimeType "Second" + // must not be greater than 2^32-1." + "Second-4294967296", + 0, + errInvalidTimeout, + }, { + // This test case comes from section 9.10.9 of the spec. It says, + // + // "In this request, the client has specified that it desires an + // infinite-length lock, if available, otherwise a timeout of 4.1 + // billion seconds, if available." + // + // The Go WebDAV package always supports infinite length locks, + // and ignores the fallback after the comma. + "Infinite, Second-4100000000", + infiniteTimeout, + nil, + }} + + for _, tc := range testCases { + got, gotErr := parseTimeout(tc.s) + if got != tc.want || gotErr != tc.wantErr { + t.Errorf("parsing %q:\ngot %v, %v\nwant %v, %v", tc.s, got, gotErr, tc.want, tc.wantErr) + } + } +} diff --git a/server/webdav/prop.go b/server/webdav/prop.go new file mode 100644 index 0000000000000000000000000000000000000000..b1474ea3e9532621ea6a6833cc4deeb042f0f637 --- /dev/null +++ b/server/webdav/prop.go @@ -0,0 +1,485 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package webdav + +import ( + "bytes" + "context" + "encoding/xml" + "errors" + "fmt" + "mime" + "net/http" + "path" + "strconv" + "strings" + "time" + + "github.com/alist-org/alist/v3/internal/model" +) + +// Proppatch describes a property update instruction as defined in RFC 4918. +// See http://www.webdav.org/specs/rfc4918.html#METHOD_PROPPATCH +type Proppatch struct { + // Remove specifies whether this patch removes properties. If it does not + // remove them, it sets them. + Remove bool + // Props contains the properties to be set or removed. + Props []Property +} + +// Propstat describes a XML propstat element as defined in RFC 4918. 
+// See http://www.webdav.org/specs/rfc4918.html#ELEMENT_propstat +type Propstat struct { + // Props contains the properties for which Status applies. + Props []Property + + // Status defines the HTTP status code of the properties in Prop. + // Allowed values include, but are not limited to the WebDAV status + // code extensions for HTTP/1.1. + // http://www.webdav.org/specs/rfc4918.html#status.code.extensions.to.http11 + Status int + + // XMLError contains the XML representation of the optional error element. + // XML content within this field must not rely on any predefined + // namespace declarations or prefixes. If empty, the XML error element + // is omitted. + XMLError string + + // ResponseDescription contains the contents of the optional + // responsedescription field. If empty, the XML element is omitted. + ResponseDescription string +} + +// makePropstats returns a slice containing those of x and y whose Props slice +// is non-empty. If both are empty, it returns a slice containing an otherwise +// zero Propstat whose HTTP status code is 200 OK. +func makePropstats(x, y Propstat) []Propstat { + pstats := make([]Propstat, 0, 2) + if len(x.Props) != 0 { + pstats = append(pstats, x) + } + if len(y.Props) != 0 { + pstats = append(pstats, y) + } + if len(pstats) == 0 { + pstats = append(pstats, Propstat{ + Status: http.StatusOK, + }) + } + return pstats +} + +// DeadPropsHolder holds the dead properties of a resource. +// +// Dead properties are those properties that are explicitly defined. In +// comparison, live properties, such as DAV:getcontentlength, are implicitly +// defined by the underlying resource, and cannot be explicitly overridden or +// removed. See the Terminology section of +// http://www.webdav.org/specs/rfc4918.html#rfc.section.3 +// +// There is a whitelist of the names of live properties. This package handles +// all live properties, and will only pass non-whitelisted names to the Patch +// method of DeadPropsHolder implementations. +type DeadPropsHolder interface { + // DeadProps returns a copy of the dead properties held. + DeadProps() (map[xml.Name]Property, error) + + // Patch patches the dead properties held. + // + // Patching is atomic; either all or no patches succeed. It returns (nil, + // non-nil) if an internal server error occurred, otherwise the Propstats + // collectively contain one Property for each proposed patch Property. If + // all patches succeed, Patch returns a slice of length one and a Propstat + // element with a 200 OK HTTP status code. If none succeed, for reasons + // other than an internal server error, no Propstat has status 200 OK. + // + // For more details on when various HTTP status codes apply, see + // http://www.webdav.org/specs/rfc4918.html#PROPPATCH-status + Patch([]Proppatch) ([]Propstat, error) +} + +// liveProps contains all supported, protected DAV: properties. +var liveProps = map[xml.Name]struct { + // findFn implements the propfind function of this property. If nil, + // it indicates a hidden property. + findFn func(context.Context, LockSystem, string, model.Obj) (string, error) + // dir is true if the property applies to directories. 
+ dir bool +}{ + {Space: "DAV:", Local: "resourcetype"}: { + findFn: findResourceType, + dir: true, + }, + {Space: "DAV:", Local: "displayname"}: { + findFn: findDisplayName, + dir: true, + }, + {Space: "DAV:", Local: "getcontentlength"}: { + findFn: findContentLength, + dir: false, + }, + {Space: "DAV:", Local: "getlastmodified"}: { + findFn: findLastModified, + // http://webdav.org/specs/rfc4918.html#PROPERTY_getlastmodified + // suggests that getlastmodified should only apply to GETable + // resources, and this package does not support GET on directories. + // + // Nonetheless, some WebDAV clients expect child directories to be + // sortable by getlastmodified date, so this value is true, not false. + // See golang.org/issue/15334. + dir: true, + }, + {Space: "DAV:", Local: "creationdate"}: { + findFn: findCreationDate, + dir: true, + }, + {Space: "DAV:", Local: "getcontentlanguage"}: { + findFn: nil, + dir: false, + }, + {Space: "DAV:", Local: "getcontenttype"}: { + findFn: findContentType, + dir: false, + }, + {Space: "DAV:", Local: "getetag"}: { + findFn: findETag, + // findETag implements ETag as the concatenated hex values of a file's + // modification time and size. This is not a reliable synchronization + // mechanism for directories, so we do not advertise getetag for DAV + // collections. + dir: false, + }, + + // TODO: The lockdiscovery property requires LockSystem to list the + // active locks on a resource. + {Space: "DAV:", Local: "lockdiscovery"}: {}, + {Space: "DAV:", Local: "supportedlock"}: { + findFn: findSupportedLock, + dir: true, + }, +} + +// TODO(nigeltao) merge props and allprop? + +// Props returns the status of the properties named pnames for resource name. +// +// Each Propstat has a unique status and each property name will only be part +// of one Propstat element. +func props(ctx context.Context, ls LockSystem, fi model.Obj, pnames []xml.Name) ([]Propstat, error) { + //f, err := fs.OpenFile(ctx, name, os.O_RDONLY, 0) + //if err != nil { + // return nil, err + //} + //defer f.Close() + //fi, err := f.Stat() + //if err != nil { + // return nil, err + //} + isDir := fi.IsDir() + + var deadProps map[xml.Name]Property + // ??? what is this for? + //if dph, ok := f.(DeadPropsHolder); ok { + // deadProps, err = dph.DeadProps() + // if err != nil { + // return nil, err + // } + //} + + pstatOK := Propstat{Status: http.StatusOK} + pstatNotFound := Propstat{Status: http.StatusNotFound} + for _, pn := range pnames { + // If this file has dead properties, check if they contain pn. + if dp, ok := deadProps[pn]; ok { + pstatOK.Props = append(pstatOK.Props, dp) + continue + } + // Otherwise, it must either be a live property or we don't know it. + if prop := liveProps[pn]; prop.findFn != nil && (prop.dir || !isDir) { + innerXML, err := prop.findFn(ctx, ls, fi.GetName(), fi) + if err != nil { + return nil, err + } + pstatOK.Props = append(pstatOK.Props, Property{ + XMLName: pn, + InnerXML: []byte(innerXML), + }) + } else { + pstatNotFound.Props = append(pstatNotFound.Props, Property{ + XMLName: pn, + }) + } + } + return makePropstats(pstatOK, pstatNotFound), nil +} + +// Propnames returns the property names defined for resource name. 
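+// For a regular file this currently amounts to the DAV: live properties
+// whose findFn is non-nil (resourcetype, displayname, getcontentlength,
+// getlastmodified, creationdate, getcontenttype, getetag and supportedlock)
+// plus the names of any dead properties.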
+func propnames(ctx context.Context, ls LockSystem, fi model.Obj) ([]xml.Name, error) {
+	//f, err := fs.OpenFile(ctx, name, os.O_RDONLY, 0)
+	//if err != nil {
+	//	return nil, err
+	//}
+	//defer f.Close()
+	//fi, err := f.Stat()
+	//if err != nil {
+	//	return nil, err
+	//}
+	isDir := fi.IsDir()
+
+	var deadProps map[xml.Name]Property
+	// ??? what is this for?
+	//if dph, ok := f.(DeadPropsHolder); ok {
+	//	deadProps, err = dph.DeadProps()
+	//	if err != nil {
+	//		return nil, err
+	//	}
+	//}
+
+	pnames := make([]xml.Name, 0, len(liveProps)+len(deadProps))
+	for pn, prop := range liveProps {
+		if prop.findFn != nil && (prop.dir || !isDir) {
+			pnames = append(pnames, pn)
+		}
+	}
+	for pn := range deadProps {
+		pnames = append(pnames, pn)
+	}
+	return pnames, nil
+}
+
+// Allprop returns the properties defined for resource name and the properties
+// named in include.
+//
+// Note that RFC 4918 defines 'allprop' to return the DAV: properties defined
+// within the RFC plus dead properties. Other live properties should only be
+// returned if they are named in 'include'.
+//
+// See http://www.webdav.org/specs/rfc4918.html#METHOD_PROPFIND
+func allprop(ctx context.Context, ls LockSystem, fi model.Obj, include []xml.Name) ([]Propstat, error) {
+	pnames, err := propnames(ctx, ls, fi)
+	if err != nil {
+		return nil, err
+	}
+	// Add names from include if they are not already covered in pnames.
+	nameset := make(map[xml.Name]bool)
+	for _, pn := range pnames {
+		nameset[pn] = true
+	}
+	for _, pn := range include {
+		if !nameset[pn] {
+			pnames = append(pnames, pn)
+		}
+	}
+	return props(ctx, ls, fi, pnames)
+}
+
+// Patch patches the properties of resource name. The return values are
+// constrained in the same manner as DeadPropsHolder.Patch.
+func patch(ctx context.Context, ls LockSystem, name string, patches []Proppatch) ([]Propstat, error) {
+	conflict := false
+loop:
+	for _, patch := range patches {
+		for _, p := range patch.Props {
+			if _, ok := liveProps[p.XMLName]; ok {
+				conflict = true
+				break loop
+			}
+		}
+	}
+	if conflict {
+		pstatForbidden := Propstat{
+			Status:   http.StatusForbidden,
+			XMLError: `<D:cannot-modify-protected-property xmlns:D="DAV:"/>`,
+		}
+		pstatFailedDep := Propstat{
+			Status: StatusFailedDependency,
+		}
+		for _, patch := range patches {
+			for _, p := range patch.Props {
+				if _, ok := liveProps[p.XMLName]; ok {
+					pstatForbidden.Props = append(pstatForbidden.Props, Property{XMLName: p.XMLName})
+				} else {
+					pstatFailedDep.Props = append(pstatFailedDep.Props, Property{XMLName: p.XMLName})
+				}
+			}
+		}
+		return makePropstats(pstatForbidden, pstatFailedDep), nil
+	}
+
+	// ------------------------------------------------------------
+	//f, err := fs.OpenFile(ctx, name, os.O_RDWR, 0)
+	//if err != nil {
+	//	return nil, err
+	//}
+	//defer f.Close()
+	//if dph, ok := f.(DeadPropsHolder); ok {
+	//	ret, err := dph.Patch(patches)
+	//	if err != nil {
+	//		return nil, err
+	//	}
+	//	// http://www.webdav.org/specs/rfc4918.html#ELEMENT_propstat says that
+	//	// "The contents of the prop XML element must only list the names of
+	//	// properties to which the result in the status element applies."
+	//	for _, pstat := range ret {
+	//		for i, p := range pstat.Props {
+	//			pstat.Props[i] = Property{XMLName: p.XMLName}
+	//		}
+	//	}
+	//	return ret, nil
+	//}
+	// ------------------------------------------------------------
+
+	// The file doesn't implement the optional DeadPropsHolder interface, so
+	// all patches are forbidden.
+	pstat := Propstat{Status: http.StatusForbidden}
+	for _, patch := range patches {
+		for _, p := range patch.Props {
+			pstat.Props = append(pstat.Props, Property{XMLName: p.XMLName})
+		}
+	}
+	return []Propstat{pstat}, nil
+}
+
+func escapeXML(s string) string {
+	for i := 0; i < len(s); i++ {
+		// As an optimization, if s contains only ASCII letters, digits or a
+		// few special characters, the escaped value is s itself and we don't
+		// need to allocate a buffer and convert between string and []byte.
+		switch c := s[i]; {
+		case c == ' ' || c == '_' ||
+			('+' <= c && c <= '9') || // Digits as well as + , - . and /
+			('A' <= c && c <= 'Z') ||
+			('a' <= c && c <= 'z'):
+			continue
+		}
+		// Otherwise, go through the full escaping process.
+		var buf bytes.Buffer
+		xml.EscapeText(&buf, []byte(s))
+		return buf.String()
+	}
+	return s
+}
+
+func findResourceType(ctx context.Context, ls LockSystem, name string, fi model.Obj) (string, error) {
+	if fi.IsDir() {
+		return `<D:collection xmlns:D="DAV:"/>`, nil
+	}
+	return "", nil
+}
+
+func findDisplayName(ctx context.Context, ls LockSystem, name string, fi model.Obj) (string, error) {
+	if slashClean(name) == "/" {
+		// Hide the real name of a possibly prefixed root directory.
+		return "", nil
+	}
+	return escapeXML(fi.GetName()), nil
+}
+
+func findContentLength(ctx context.Context, ls LockSystem, name string, fi model.Obj) (string, error) {
+	return strconv.FormatInt(fi.GetSize(), 10), nil
+}
+
+func findLastModified(ctx context.Context, ls LockSystem, name string, fi model.Obj) (string, error) {
+	return fi.ModTime().UTC().Format(http.TimeFormat), nil
+}
+
+func findCreationDate(ctx context.Context, ls LockSystem, name string, fi model.Obj) (string, error) {
+	// The Windows WebDAV mini-redirector expects the HTTP time format here;
+	// other clients get the RFC 3339 format that RFC 4918 specifies.
+	userAgent := ctx.Value("userAgent").(string)
+	if strings.Contains(strings.ToLower(userAgent), "microsoft-webdav") {
+		return fi.CreateTime().UTC().Format(http.TimeFormat), nil
+	}
+	return fi.CreateTime().UTC().Format(time.RFC3339), nil
+}
+
+// ErrNotImplemented should be returned by optional interfaces if they
+// want the original implementation to be used.
+var ErrNotImplemented = errors.New("not implemented")
+
+// ContentTyper is an optional interface for the os.FileInfo
+// objects returned by the FileSystem.
+//
+// If this interface is defined then it will be used to read the
+// content type from the object.
+//
+// If this interface is not defined the file will be opened and the
+// content type will be guessed from the initial contents of the file.
type ContentTyper interface {
+	// ContentType returns the content type for the file.
+	//
+	// If this returns error ErrNotImplemented then the error will
+	// be ignored and the base implementation will be used
+	// instead.
+	ContentType(ctx context.Context) (string, error)
+}
+
+func findContentType(ctx context.Context, ls LockSystem, name string, fi model.Obj) (string, error) {
+	//if do, ok := fi.(ContentTyper); ok {
+	//	ctype, err := do.ContentType(ctx)
+	//	if err != ErrNotImplemented {
+	//		return ctype, err
+	//	}
+	//}
+	//f, err := fs.OpenFile(ctx, name, os.O_RDONLY, 0)
+	//if err != nil {
+	//	return "", err
+	//}
+	//defer f.Close()
+	// This implementation is based on serveContent's code in the standard net/http package.
+	ctype := mime.TypeByExtension(path.Ext(name))
+	return ctype, nil
+	//if ctype != "" {
+	//	return ctype, nil
+	//}
+	//return "application/octet-stream", nil
+	// Read a chunk to decide between utf-8 text and binary.
+	//var buf [512]byte
+	//n, err := io.ReadFull(f, buf[:])
+	//if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF {
+	//	return "", err
+	//}
+	//ctype = http.DetectContentType(buf[:n])
+	//// Rewind file.
+	//_, err = f.Seek(0, os.SEEK_SET)
+	//return ctype, err
+}
+
+// ETager is an optional interface for the os.FileInfo objects
+// returned by the FileSystem.
+//
+// If this interface is defined then it will be used to read the ETag
+// for the object.
+//
+// If this interface is not defined an ETag will be computed using the
+// ModTime() and the Size() methods of the os.FileInfo object.
+type ETager interface {
+	// ETag returns an ETag for the file. This should be of the
+	// form "value" or W/"value"
+	//
+	// If this returns error ErrNotImplemented then the error will
+	// be ignored and the base implementation will be used
+	// instead.
+	ETag(ctx context.Context) (string, error)
+}
+
+func findETag(ctx context.Context, ls LockSystem, name string, fi model.Obj) (string, error) {
+	if do, ok := fi.(ETager); ok {
+		etag, err := do.ETag(ctx)
+		if !errors.Is(err, ErrNotImplemented) {
+			return etag, err
+		}
+	}
+	// The Apache http 2.4 web server by default concatenates the
+	// modification time and size of a file. We replicate the heuristic
+	// with nanosecond granularity.
+	return fmt.Sprintf(`"%x%x"`, fi.ModTime().UnixNano(), fi.GetSize()), nil
+}
+
+func findSupportedLock(ctx context.Context, ls LockSystem, name string, fi model.Obj) (string, error) {
+	return `` +
+		`<D:lockentry xmlns:D="DAV:">` +
+		`<D:lockscope><D:exclusive/></D:lockscope>` +
+		`<D:locktype><D:write/></D:locktype>` +
+		`</D:lockentry>`, nil
+}
diff --git a/server/webdav/util.go b/server/webdav/util.go
new file mode 100644
index 0000000000000000000000000000000000000000..15d9e07cc56065be0f99ae9709b95b770dc3fd01
--- /dev/null
+++ b/server/webdav/util.go
@@ -0,0 +1,29 @@
+package webdav
+
+import (
+	"net/http"
+	"strconv"
+	"time"
+
+	log "github.com/sirupsen/logrus"
+)
+
+func (h *Handler) getModTime(r *http.Request) time.Time {
+	return h.getHeaderTime(r, "X-OC-Mtime")
+}
+
+// ownCloud/Nextcloud haven't implemented this header yet, but we can add
+// support now since rclone may start sending it soon.
+func (h *Handler) getCreateTime(r *http.Request) time.Time {
+	return h.getHeaderTime(r, "X-OC-Ctime")
+}
+
+// getHeaderTime parses a Unix-epoch timestamp from the given request header.
+// If the header is absent or malformed, it falls back to the current time.
+func (h *Handler) getHeaderTime(r *http.Request, header string) time.Time {
+	hVal := r.Header.Get(header)
+	if hVal != "" {
+		modTimeUnix, err := strconv.ParseInt(hVal, 10, 64)
+		if err == nil {
+			return time.Unix(modTimeUnix, 0)
+		}
+		log.Warnf("getHeaderTime in WebDAV: failed to parse %s: %s", header, err)
+	}
+	return time.Now()
+}
diff --git a/server/webdav/webdav.go b/server/webdav/webdav.go
new file mode 100644
index 0000000000000000000000000000000000000000..390e540997619eca1a8d1e6f48795994903d4c06
--- /dev/null
+++ b/server/webdav/webdav.go
@@ -0,0 +1,826 @@
+// Copyright 2014 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package webdav provides a WebDAV server implementation.
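+//
+// A minimal wiring sketch (illustrative only, not part of this patch; it
+// assumes upstream middleware has already stored a *model.User under the
+// "user" context key, which the handlers below rely on):
+//
+//	h := &webdav.Handler{
+//		Prefix:     "/dav",
+//		LockSystem: webdav.NewMemLS(),
+//		Logger: func(r *http.Request, err error) {
+//			log.Errorf("%s %s: %+v", r.Method, r.URL.Path, err)
+//		},
+//	}
+//	http.Handle("/dav/", h)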
+package webdav
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"net/http"
+	"net/url"
+	"os"
+	"path"
+	"strings"
+	"time"
+
+	"github.com/alist-org/alist/v3/internal/stream"
+
+	"github.com/alist-org/alist/v3/internal/errs"
+	"github.com/alist-org/alist/v3/internal/fs"
+	"github.com/alist-org/alist/v3/internal/model"
+	"github.com/alist-org/alist/v3/internal/sign"
+	"github.com/alist-org/alist/v3/pkg/utils"
+	"github.com/alist-org/alist/v3/server/common"
+	log "github.com/sirupsen/logrus"
+)
+
+type Handler struct {
+	// Prefix is the URL path prefix to strip from WebDAV resource paths.
+	Prefix string
+	// LockSystem is the lock management system.
+	LockSystem LockSystem
+	// Logger is an optional error logger. If non-nil, it will be called
+	// for all HTTP requests.
+	Logger func(*http.Request, error)
+}
+
+func (h *Handler) stripPrefix(p string) (string, int, error) {
+	if h.Prefix == "" {
+		return p, http.StatusOK, nil
+	}
+	if r := strings.TrimPrefix(p, h.Prefix); len(r) < len(p) {
+		return r, http.StatusOK, nil
+	}
+	return p, http.StatusNotFound, errPrefixMismatch
+}
+
+func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+	status, err := http.StatusBadRequest, errUnsupportedMethod
+	brw := newBufferedResponseWriter()
+	useBufferedWriter := true
+	if h.LockSystem == nil {
+		status, err = http.StatusInternalServerError, errNoLockSystem
+	} else {
+		switch r.Method {
+		case "OPTIONS":
+			status, err = h.handleOptions(brw, r)
+		case "GET", "HEAD", "POST":
+			useBufferedWriter = false
+			status, err = h.handleGetHeadPost(w, r)
+		case "DELETE":
+			status, err = h.handleDelete(brw, r)
+		case "PUT":
+			status, err = h.handlePut(brw, r)
+		case "MKCOL":
+			status, err = h.handleMkcol(brw, r)
+		case "COPY", "MOVE":
+			status, err = h.handleCopyMove(brw, r)
+		case "LOCK":
+			status, err = h.handleLock(brw, r)
+		case "UNLOCK":
+			status, err = h.handleUnlock(brw, r)
+		case "PROPFIND":
+			status, err = h.handlePropfind(brw, r)
+			// If PROPFIND fails, report 404 so that the path appears to the
+			// client as an empty folder rather than as a server error.
+			if err != nil {
+				status = http.StatusNotFound
+			}
+		case "PROPPATCH":
+			status, err = h.handleProppatch(brw, r)
+		}
+	}
+
+	if status != 0 {
+		w.WriteHeader(status)
+		if status != http.StatusNoContent {
+			w.Write([]byte(StatusText(status)))
+		}
+	} else if useBufferedWriter {
+		brw.WriteToResponse(w)
+	}
+	if h.Logger != nil && err != nil {
+		h.Logger(r, err)
+	}
+}
+
+func (h *Handler) lock(now time.Time, root string) (token string, status int, err error) {
+	token, err = h.LockSystem.Create(now, LockDetails{
+		Root:      root,
+		Duration:  infiniteTimeout,
+		ZeroDepth: true,
+	})
+	if err != nil {
+		if err == ErrLocked {
+			return "", StatusLocked, err
+		}
+		return "", http.StatusInternalServerError, err
+	}
+	return token, 0, nil
+}
+
+func (h *Handler) confirmLocks(r *http.Request, src, dst string) (release func(), status int, err error) {
+	hdr := r.Header.Get("If")
+	if hdr == "" {
+		// An empty If header means that the client hasn't previously created locks.
+		// Even if this client doesn't care about locks, we still need to check that
+		// the resources aren't locked by another client, so we create temporary
+		// locks that would conflict with another client's locks. These temporary
+		// locks are unlocked at the end of the HTTP request.
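+		// (For reference: a client that does hold a lock sends it in an If
+		// header instead, e.g. `If: (<token>)` or, with a resource tag,
+		// `If: <http://host/dav/file> (<token>)`; parseIfHeader below turns
+		// that into a disjunction of ifLists.)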
+ now, srcToken, dstToken := time.Now(), "", "" + if src != "" { + srcToken, status, err = h.lock(now, src) + if err != nil { + return nil, status, err + } + } + if dst != "" { + dstToken, status, err = h.lock(now, dst) + if err != nil { + if srcToken != "" { + h.LockSystem.Unlock(now, srcToken) + } + return nil, status, err + } + } + + return func() { + if dstToken != "" { + h.LockSystem.Unlock(now, dstToken) + } + if srcToken != "" { + h.LockSystem.Unlock(now, srcToken) + } + }, 0, nil + } + + ih, ok := parseIfHeader(hdr) + if !ok { + return nil, http.StatusBadRequest, errInvalidIfHeader + } + // ih is a disjunction (OR) of ifLists, so any ifList will do. + for _, l := range ih.lists { + lsrc := l.resourceTag + if lsrc == "" { + lsrc = src + } else { + u, err := url.Parse(lsrc) + if err != nil { + continue + } + if u.Host != r.Host { + continue + } + lsrc, status, err = h.stripPrefix(u.Path) + if err != nil { + return nil, status, err + } + } + release, err = h.LockSystem.Confirm(time.Now(), lsrc, dst, l.conditions...) + if err == ErrConfirmationFailed { + continue + } + if err != nil { + return nil, http.StatusInternalServerError, err + } + return release, 0, nil + } + // Section 10.4.1 says that "If this header is evaluated and all state lists + // fail, then the request must fail with a 412 (Precondition Failed) status." + // We follow the spec even though the cond_put_corrupt_token test case from + // the litmus test warns on seeing a 412 instead of a 423 (Locked). + return nil, http.StatusPreconditionFailed, ErrLocked +} + +func (h *Handler) handleOptions(w http.ResponseWriter, r *http.Request) (status int, err error) { + reqPath, status, err := h.stripPrefix(r.URL.Path) + if err != nil { + return status, err + } + ctx := r.Context() + user := ctx.Value("user").(*model.User) + reqPath, err = user.JoinPath(reqPath) + if err != nil { + return 403, err + } + allow := "OPTIONS, LOCK, PUT, MKCOL" + if fi, err := fs.Get(ctx, reqPath, &fs.GetArgs{}); err == nil { + if fi.IsDir() { + allow = "OPTIONS, LOCK, DELETE, PROPPATCH, COPY, MOVE, UNLOCK, PROPFIND" + } else { + allow = "OPTIONS, LOCK, GET, HEAD, POST, DELETE, PROPPATCH, COPY, MOVE, UNLOCK, PROPFIND, PUT" + } + } + w.Header().Set("Allow", allow) + // http://www.webdav.org/specs/rfc4918.html#dav.compliance.classes + w.Header().Set("DAV", "1, 2") + // http://msdn.microsoft.com/en-au/library/cc250217.aspx + w.Header().Set("MS-Author-Via", "DAV") + return 0, nil +} + +func (h *Handler) handleGetHeadPost(w http.ResponseWriter, r *http.Request) (status int, err error) { + reqPath, status, err := h.stripPrefix(r.URL.Path) + if err != nil { + return status, err + } + // TODO: check locks for read-only access?? + ctx := r.Context() + user := ctx.Value("user").(*model.User) + reqPath, err = user.JoinPath(reqPath) + if err != nil { + return http.StatusForbidden, err + } + fi, err := fs.Get(ctx, reqPath, &fs.GetArgs{}) + if err != nil { + return http.StatusNotFound, err + } + etag, err := findETag(ctx, h.LockSystem, reqPath, fi) + if err != nil { + return http.StatusInternalServerError, err + } + w.Header().Set("ETag", etag) + if r.Method == http.MethodHead { + w.Header().Set("Content-Length", fmt.Sprintf("%d", fi.GetSize())) + return http.StatusOK, nil + } + if fi.IsDir() { + return http.StatusMethodNotAllowed, nil + } + // Let ServeContent determine the Content-Type header. 
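+	// (In this fork the file body is not actually served via
+	// http.ServeContent: depending on the storage's WebDAV settings the
+	// request is either proxied through alist, redirected to a signed
+	// down-proxy URL, or redirected (302) to the storage's own link, as the
+	// branches below show.)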
+ storage, _ := fs.GetStorage(reqPath, &fs.GetStoragesArgs{}) + downProxyUrl := storage.GetStorage().DownProxyUrl + if storage.GetStorage().WebdavNative() || (storage.GetStorage().WebdavProxy() && downProxyUrl == "") { + link, _, err := fs.Link(ctx, reqPath, model.LinkArgs{Header: r.Header, HttpReq: r}) + if err != nil { + return http.StatusInternalServerError, err + } + err = common.Proxy(w, r, link, fi) + if err != nil { + log.Errorf("webdav proxy error: %+v", err) + return http.StatusInternalServerError, err + } + } else if storage.GetStorage().WebdavProxy() && downProxyUrl != "" { + u := fmt.Sprintf("%s%s?sign=%s", + strings.Split(downProxyUrl, "\n")[0], + utils.EncodePath(reqPath, true), + sign.Sign(reqPath)) + w.Header().Set("Cache-Control", "max-age=0, no-cache, no-store, must-revalidate") + http.Redirect(w, r, u, http.StatusFound) + } else { + link, _, err := fs.Link(ctx, reqPath, model.LinkArgs{IP: utils.ClientIP(r), Header: r.Header, HttpReq: r}) + if err != nil { + return http.StatusInternalServerError, err + } + http.Redirect(w, r, link.URL, http.StatusFound) + } + return 0, nil +} + +func (h *Handler) handleDelete(w http.ResponseWriter, r *http.Request) (status int, err error) { + reqPath, status, err := h.stripPrefix(r.URL.Path) + if err != nil { + return status, err + } + release, status, err := h.confirmLocks(r, reqPath, "") + if err != nil { + return status, err + } + defer release() + + ctx := r.Context() + user := ctx.Value("user").(*model.User) + reqPath, err = user.JoinPath(reqPath) + if err != nil { + return 403, err + } + // TODO: return MultiStatus where appropriate. + + // "godoc os RemoveAll" says that "If the path does not exist, RemoveAll + // returns nil (no error)." WebDAV semantics are that it should return a + // "404 Not Found". We therefore have to Stat before we RemoveAll. + if _, err := fs.Get(ctx, reqPath, &fs.GetArgs{}); err != nil { + if errs.IsObjectNotFound(err) { + return http.StatusNotFound, err + } + return http.StatusMethodNotAllowed, err + } + if err := fs.Remove(ctx, reqPath); err != nil { + return http.StatusMethodNotAllowed, err + } + //fs.ClearCache(path.Dir(reqPath)) + return http.StatusNoContent, nil +} + +func (h *Handler) handlePut(w http.ResponseWriter, r *http.Request) (status int, err error) { + reqPath, status, err := h.stripPrefix(r.URL.Path) + if err != nil { + return status, err + } + if reqPath == "" { + return http.StatusMethodNotAllowed, nil + } + release, status, err := h.confirmLocks(r, reqPath, "") + if err != nil { + return status, err + } + defer release() + // TODO(rost): Support the If-Match, If-None-Match headers? See bradfitz' + // comments in http.checkEtag. + ctx := r.Context() + user := ctx.Value("user").(*model.User) + reqPath, err = user.JoinPath(reqPath) + if err != nil { + return http.StatusForbidden, err + } + obj := model.Object{ + Name: path.Base(reqPath), + Size: r.ContentLength, + Modified: h.getModTime(r), + Ctime: h.getCreateTime(r), + } + stream := &stream.FileStream{ + Obj: &obj, + Reader: r.Body, + Mimetype: r.Header.Get("Content-Type"), + } + if stream.Mimetype == "" { + stream.Mimetype = utils.GetMimeType(reqPath) + } + err = fs.PutDirectly(ctx, path.Dir(reqPath), stream) + if errs.IsNotFoundError(err) { + return http.StatusNotFound, err + } + + _ = r.Body.Close() + _ = stream.Close() + // TODO(rost): Returning 405 Method Not Allowed might not be appropriate. 
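+	// (For context: getModTime/getCreateTime above read the ownCloud-style
+	// X-OC-Mtime / X-OC-Ctime headers as Unix epochs, so a client such as
+	// rclone can ask to preserve timestamps on upload; see util.go.)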
+ if err != nil { + return http.StatusMethodNotAllowed, err + } + fi, err := fs.Get(ctx, reqPath, &fs.GetArgs{}) + if err != nil { + fi = &obj + } + etag, err := findETag(ctx, h.LockSystem, reqPath, fi) + if err != nil { + return http.StatusInternalServerError, err + } + w.Header().Set("ETag", etag) + return http.StatusCreated, nil +} + +func (h *Handler) handleMkcol(w http.ResponseWriter, r *http.Request) (status int, err error) { + reqPath, status, err := h.stripPrefix(r.URL.Path) + if err != nil { + return status, err + } + release, status, err := h.confirmLocks(r, reqPath, "") + if err != nil { + return status, err + } + defer release() + + ctx := r.Context() + user := ctx.Value("user").(*model.User) + reqPath, err = user.JoinPath(reqPath) + if err != nil { + return 403, err + } + + if r.ContentLength > 0 { + return http.StatusUnsupportedMediaType, nil + } + + // RFC 4918 9.3.1 + //405 (Method Not Allowed) - MKCOL can only be executed on an unmapped URL + if _, err := fs.Get(ctx, reqPath, &fs.GetArgs{}); err == nil { + return http.StatusMethodNotAllowed, err + } + // RFC 4918 9.3.1 + // 409 (Conflict) The server MUST NOT create those intermediate collections automatically. + reqDir := path.Dir(reqPath) + if _, err := fs.Get(ctx, reqDir, &fs.GetArgs{}); err != nil { + if errs.IsObjectNotFound(err) { + return http.StatusConflict, err + } + return http.StatusMethodNotAllowed, err + } + if err := fs.MakeDir(ctx, reqPath); err != nil { + if os.IsNotExist(err) { + return http.StatusConflict, err + } + return http.StatusMethodNotAllowed, err + } + return http.StatusCreated, nil +} + +func (h *Handler) handleCopyMove(w http.ResponseWriter, r *http.Request) (status int, err error) { + hdr := r.Header.Get("Destination") + if hdr == "" { + return http.StatusBadRequest, errInvalidDestination + } + u, err := url.Parse(hdr) + if err != nil { + return http.StatusBadRequest, errInvalidDestination + } + if u.Host != "" && u.Host != r.Host { + return http.StatusBadGateway, errInvalidDestination + } + + src, status, err := h.stripPrefix(r.URL.Path) + if err != nil { + return status, err + } + + dst, status, err := h.stripPrefix(u.Path) + if err != nil { + return status, err + } + + if dst == "" { + return http.StatusBadGateway, errInvalidDestination + } + if dst == src { + return http.StatusForbidden, errDestinationEqualsSource + } + + ctx := r.Context() + user := ctx.Value("user").(*model.User) + src, err = user.JoinPath(src) + if err != nil { + return 403, err + } + dst, err = user.JoinPath(dst) + if err != nil { + return 403, err + } + + if r.Method == "COPY" { + // Section 7.5.1 says that a COPY only needs to lock the destination, + // not both destination and source. Strictly speaking, this is racy, + // even though a COPY doesn't modify the source, if a concurrent + // operation modifies the source. However, the litmus test explicitly + // checks that COPYing a locked-by-another source is OK. + release, status, err := h.confirmLocks(r, "", dst) + if err != nil { + return status, err + } + defer release() + + // Section 9.8.3 says that "The COPY method on a collection without a Depth + // header must act as if a Depth header with value "infinity" was included". + depth := infiniteDepth + if hdr := r.Header.Get("Depth"); hdr != "" { + depth = parseDepth(hdr) + if depth != 0 && depth != infiniteDepth { + // Section 9.8.3 says that "A client may submit a Depth header on a + // COPY on a collection with a value of "0" or "infinity"." 
+ return http.StatusBadRequest, errInvalidDepth + } + } + return copyFiles(ctx, src, dst, r.Header.Get("Overwrite") != "F") + } + + release, status, err := h.confirmLocks(r, src, dst) + if err != nil { + return status, err + } + defer release() + + // Section 9.9.2 says that "The MOVE method on a collection must act as if + // a "Depth: infinity" header was used on it. A client must not submit a + // Depth header on a MOVE on a collection with any value but "infinity"." + if hdr := r.Header.Get("Depth"); hdr != "" { + if parseDepth(hdr) != infiniteDepth { + return http.StatusBadRequest, errInvalidDepth + } + } + return moveFiles(ctx, src, dst, r.Header.Get("Overwrite") == "T") +} + +func (h *Handler) handleLock(w http.ResponseWriter, r *http.Request) (retStatus int, retErr error) { + duration, err := parseTimeout(r.Header.Get("Timeout")) + if err != nil { + return http.StatusBadRequest, err + } + li, status, err := readLockInfo(r.Body) + if err != nil { + return status, err + } + + ctx := r.Context() + user := ctx.Value("user").(*model.User) + token, ld, now, created := "", LockDetails{}, time.Now(), false + if li == (lockInfo{}) { + // An empty lockInfo means to refresh the lock. + ih, ok := parseIfHeader(r.Header.Get("If")) + if !ok { + return http.StatusBadRequest, errInvalidIfHeader + } + if len(ih.lists) == 1 && len(ih.lists[0].conditions) == 1 { + token = ih.lists[0].conditions[0].Token + } + if token == "" { + return http.StatusBadRequest, errInvalidLockToken + } + ld, err = h.LockSystem.Refresh(now, token, duration) + if err != nil { + if err == ErrNoSuchLock { + return http.StatusPreconditionFailed, err + } + return http.StatusInternalServerError, err + } + + } else { + // Section 9.10.3 says that "If no Depth header is submitted on a LOCK request, + // then the request MUST act as if a "Depth:infinity" had been submitted." + depth := infiniteDepth + if hdr := r.Header.Get("Depth"); hdr != "" { + depth = parseDepth(hdr) + if depth != 0 && depth != infiniteDepth { + // Section 9.10.3 says that "Values other than 0 or infinity must not be + // used with the Depth header on a LOCK method". + return http.StatusBadRequest, errInvalidDepth + } + } + reqPath, status, err := h.stripPrefix(r.URL.Path) + if err != nil { + return status, err + } + reqPath, err = user.JoinPath(reqPath) + if err != nil { + return 403, err + } + ld = LockDetails{ + Root: reqPath, + Duration: duration, + OwnerXML: li.Owner.InnerXML, + ZeroDepth: depth == 0, + } + token, err = h.LockSystem.Create(now, ld) + if err != nil { + if err == ErrLocked { + return StatusLocked, err + } + return http.StatusInternalServerError, err + } + defer func() { + if retErr != nil { + h.LockSystem.Unlock(now, token) + } + }() + + // ??? Why create resource here? + //// Create the resource if it didn't previously exist. + //if _, err := h.FileSystem.Stat(ctx, reqPath); err != nil { + // f, err := h.FileSystem.OpenFile(ctx, reqPath, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0666) + // if err != nil { + // // TODO: detect missing intermediate dirs and return http.StatusConflict? + // return http.StatusInternalServerError, err + // } + // f.Close() + // created = true + //} + + // http://www.webdav.org/specs/rfc4918.html#HEADER_Lock-Token says that the + // Lock-Token value is a Coded-URL. We add angle brackets. 
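+		// (Illustrative: the response header thus looks like
+		// `Lock-Token: <sometoken>`; the exact token format is an opaque
+		// detail of the LockSystem implementation.)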
+ w.Header().Set("Lock-Token", "<"+token+">") + } + + w.Header().Set("Content-Type", "application/xml; charset=utf-8") + if created { + // This is "w.WriteHeader(http.StatusCreated)" and not "return + // http.StatusCreated, nil" because we write our own (XML) response to w + // and Handler.ServeHTTP would otherwise write "Created". + w.WriteHeader(http.StatusCreated) + } + writeLockInfo(w, token, ld) + return 0, nil +} + +func (h *Handler) handleUnlock(w http.ResponseWriter, r *http.Request) (status int, err error) { + // http://www.webdav.org/specs/rfc4918.html#HEADER_Lock-Token says that the + // Lock-Token value is a Coded-URL. We strip its angle brackets. + t := r.Header.Get("Lock-Token") + if len(t) < 2 || t[0] != '<' || t[len(t)-1] != '>' { + return http.StatusBadRequest, errInvalidLockToken + } + t = t[1 : len(t)-1] + + switch err = h.LockSystem.Unlock(time.Now(), t); err { + case nil: + return http.StatusNoContent, err + case ErrForbidden: + return http.StatusForbidden, err + case ErrLocked: + return StatusLocked, err + case ErrNoSuchLock: + return http.StatusConflict, err + default: + return http.StatusInternalServerError, err + } +} + +func (h *Handler) handlePropfind(w http.ResponseWriter, r *http.Request) (status int, err error) { + reqPath, status, err := h.stripPrefix(r.URL.Path) + if err != nil { + return status, err + } + ctx := r.Context() + userAgent := r.Header.Get("User-Agent") + ctx = context.WithValue(ctx, "userAgent", userAgent) + user := ctx.Value("user").(*model.User) + reqPath, err = user.JoinPath(reqPath) + if err != nil { + return 403, err + } + fi, err := fs.Get(ctx, reqPath, &fs.GetArgs{}) + if err != nil { + if errs.IsNotFoundError(err) { + return http.StatusNotFound, err + } + return http.StatusMethodNotAllowed, err + } + depth := infiniteDepth + if hdr := r.Header.Get("Depth"); hdr != "" { + depth = parseDepth(hdr) + if depth == invalidDepth { + return http.StatusBadRequest, errInvalidDepth + } + } + pf, status, err := readPropfind(r.Body) + if err != nil { + return status, err + } + + mw := multistatusWriter{w: w} + + walkFn := func(reqPath string, info model.Obj, err error) error { + if err != nil { + return err + } + var pstats []Propstat + if pf.Propname != nil { + pnames, err := propnames(ctx, h.LockSystem, info) + if err != nil { + return err + } + pstat := Propstat{Status: http.StatusOK} + for _, xmlname := range pnames { + pstat.Props = append(pstat.Props, Property{XMLName: xmlname}) + } + pstats = append(pstats, pstat) + } else if pf.Allprop != nil { + pstats, err = allprop(ctx, h.LockSystem, info, pf.Prop) + } else { + pstats, err = props(ctx, h.LockSystem, info, pf.Prop) + } + if err != nil { + return err + } + href := path.Join(h.Prefix, strings.TrimPrefix(reqPath, user.BasePath)) + if href != "/" && info.IsDir() { + href += "/" + } + return mw.write(makePropstatResponse(href, pstats)) + } + + walkErr := walkFS(ctx, depth, reqPath, fi, walkFn) + closeErr := mw.close() + if walkErr != nil { + return http.StatusInternalServerError, walkErr + } + if closeErr != nil { + return http.StatusInternalServerError, closeErr + } + return 0, nil +} + +func (h *Handler) handleProppatch(w http.ResponseWriter, r *http.Request) (status int, err error) { + reqPath, status, err := h.stripPrefix(r.URL.Path) + if err != nil { + return status, err + } + release, status, err := h.confirmLocks(r, reqPath, "") + if err != nil { + return status, err + } + defer release() + + ctx := r.Context() + user := ctx.Value("user").(*model.User) + reqPath, err = 
user.JoinPath(reqPath) + if err != nil { + return 403, err + } + if _, err := fs.Get(ctx, reqPath, &fs.GetArgs{}); err != nil { + if errs.IsObjectNotFound(err) { + return http.StatusNotFound, err + } + return http.StatusMethodNotAllowed, err + } + patches, status, err := readProppatch(r.Body) + if err != nil { + return status, err + } + pstats, err := patch(ctx, h.LockSystem, reqPath, patches) + if err != nil { + return http.StatusInternalServerError, err + } + mw := multistatusWriter{w: w} + writeErr := mw.write(makePropstatResponse(r.URL.Path, pstats)) + closeErr := mw.close() + if writeErr != nil { + return http.StatusInternalServerError, writeErr + } + if closeErr != nil { + return http.StatusInternalServerError, closeErr + } + return 0, nil +} + +func makePropstatResponse(href string, pstats []Propstat) *response { + resp := response{ + Href: []string{(&url.URL{Path: href}).EscapedPath()}, + Propstat: make([]propstat, 0, len(pstats)), + } + for _, p := range pstats { + var xmlErr *xmlError + if p.XMLError != "" { + xmlErr = &xmlError{InnerXML: []byte(p.XMLError)} + } + resp.Propstat = append(resp.Propstat, propstat{ + Status: fmt.Sprintf("HTTP/1.1 %d %s", p.Status, StatusText(p.Status)), + Prop: p.Props, + ResponseDescription: p.ResponseDescription, + Error: xmlErr, + }) + } + return &resp +} + +const ( + infiniteDepth = -1 + invalidDepth = -2 +) + +// parseDepth maps the strings "0", "1" and "infinity" to 0, 1 and +// infiniteDepth. Parsing any other string returns invalidDepth. +// +// Different WebDAV methods have further constraints on valid depths: +// - PROPFIND has no further restrictions, as per section 9.1. +// - COPY accepts only "0" or "infinity", as per section 9.8.3. +// - MOVE accepts only "infinity", as per section 9.9.2. +// - LOCK accepts only "0" or "infinity", as per section 9.10.3. +// +// These constraints are enforced by the handleXxx methods. 
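+//
+// For example (follows directly from the mapping below):
+//
+//	parseDepth("0")        // == 0
+//	parseDepth("1")        // == 1
+//	parseDepth("infinity") // == infiniteDepth
+//	parseDepth("2")        // == invalidDepth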
+func parseDepth(s string) int { + switch s { + case "0": + return 0 + case "1": + return 1 + case "infinity": + return infiniteDepth + } + return invalidDepth +} + +// http://www.webdav.org/specs/rfc4918.html#status.code.extensions.to.http11 +const ( + StatusMulti = 207 + StatusUnprocessableEntity = 422 + StatusLocked = 423 + StatusFailedDependency = 424 + StatusInsufficientStorage = 507 +) + +func StatusText(code int) string { + switch code { + case StatusMulti: + return "Multi-Status" + case StatusUnprocessableEntity: + return "Unprocessable Entity" + case StatusLocked: + return "Locked" + case StatusFailedDependency: + return "Failed Dependency" + case StatusInsufficientStorage: + return "Insufficient Storage" + } + return http.StatusText(code) +} + +var ( + errDestinationEqualsSource = errors.New("webdav: destination equals source") + errDirectoryNotEmpty = errors.New("webdav: directory not empty") + errInvalidDepth = errors.New("webdav: invalid depth") + errInvalidDestination = errors.New("webdav: invalid destination") + errInvalidIfHeader = errors.New("webdav: invalid If header") + errInvalidLockInfo = errors.New("webdav: invalid lock info") + errInvalidLockToken = errors.New("webdav: invalid lock token") + errInvalidPropfind = errors.New("webdav: invalid propfind") + errInvalidProppatch = errors.New("webdav: invalid proppatch") + errInvalidResponse = errors.New("webdav: invalid response") + errInvalidTimeout = errors.New("webdav: invalid timeout") + errNoFileSystem = errors.New("webdav: no file system") + errNoLockSystem = errors.New("webdav: no lock system") + errNotADirectory = errors.New("webdav: not a directory") + errPrefixMismatch = errors.New("webdav: prefix mismatch") + errRecursionTooDeep = errors.New("webdav: recursion too deep") + errUnsupportedLockInfo = errors.New("webdav: unsupported lock info") + errUnsupportedMethod = errors.New("webdav: unsupported method") +) diff --git a/server/webdav/xml.go b/server/webdav/xml.go new file mode 100644 index 0000000000000000000000000000000000000000..341f4e46694beebee2350943d72d257abec2b25d --- /dev/null +++ b/server/webdav/xml.go @@ -0,0 +1,519 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package webdav + +// The XML encoding is covered by Section 14. +// http://www.webdav.org/specs/rfc4918.html#xml.element.definitions + +import ( + "bytes" + "encoding/xml" + "fmt" + "io" + "net/http" + "time" + + // As of https://go-review.googlesource.com/#/c/12772/ which was submitted + // in July 2015, this package uses an internal fork of the standard + // library's encoding/xml package, due to changes in the way namespaces + // were encoded. Such changes were introduced in the Go 1.5 cycle, but were + // rolled back in response to https://github.com/golang/go/issues/11841 + // + // However, this package's exported API, specifically the Property and + // DeadPropsHolder types, need to refer to the standard library's version + // of the xml.Name type, as code that imports this package cannot refer to + // the internal version. + // + // This file therefore imports both the internal and external versions, as + // ixml and xml, and converts between them. + // + // In the long term, this package should use the standard library's version + // only, and the internal fork deleted, once + // https://github.com/golang/go/issues/13400 is resolved. 
+	ixml "github.com/alist-org/alist/v3/server/webdav/internal/xml"
+)
+
+// http://www.webdav.org/specs/rfc4918.html#ELEMENT_lockinfo
+type lockInfo struct {
+	XMLName   ixml.Name `xml:"lockinfo"`
+	Exclusive *struct{} `xml:"lockscope>exclusive"`
+	Shared    *struct{} `xml:"lockscope>shared"`
+	Write     *struct{} `xml:"locktype>write"`
+	Owner     owner     `xml:"owner"`
+}
+
+// http://www.webdav.org/specs/rfc4918.html#ELEMENT_owner
+type owner struct {
+	InnerXML string `xml:",innerxml"`
+}
+
+func readLockInfo(r io.Reader) (li lockInfo, status int, err error) {
+	c := &countingReader{r: r}
+	if err = ixml.NewDecoder(c).Decode(&li); err != nil {
+		if err == io.EOF {
+			if c.n == 0 {
+				// An empty body means to refresh the lock.
+				// http://www.webdav.org/specs/rfc4918.html#refreshing-locks
+				return lockInfo{}, 0, nil
+			}
+			err = errInvalidLockInfo
+		}
+		return lockInfo{}, http.StatusBadRequest, err
+	}
+	// We only support exclusive (non-shared) write locks. In practice, these are
+	// the only types of locks that seem to matter.
+	if li.Exclusive == nil || li.Shared != nil || li.Write == nil {
+		return lockInfo{}, http.StatusNotImplemented, errUnsupportedLockInfo
+	}
+	return li, 0, nil
+}
+
+type countingReader struct {
+	n int
+	r io.Reader
+}
+
+func (c *countingReader) Read(p []byte) (int, error) {
+	n, err := c.r.Read(p)
+	c.n += n
+	return n, err
+}
+
+func writeLockInfo(w io.Writer, token string, ld LockDetails) (int, error) {
+	depth := "infinity"
+	if ld.ZeroDepth {
+		depth = "0"
+	}
+	timeout := ld.Duration / time.Second
+	return fmt.Fprintf(w, "<?xml version=\"1.0\" encoding=\"utf-8\"?>\n"+
+		"<D:prop xmlns:D=\"DAV:\"><D:lockdiscovery><D:activelock>\n"+
+		"	<D:locktype><D:write/></D:locktype>\n"+
+		"	<D:lockscope><D:exclusive/></D:lockscope>\n"+
+		"	<D:depth>%s</D:depth>\n"+
+		"	<D:owner>%s</D:owner>\n"+
+		"	<D:timeout>Second-%d</D:timeout>\n"+
+		"	<D:locktoken><D:href>%s</D:href></D:locktoken>\n"+
+		"	<D:lockroot><D:href>%s</D:href></D:lockroot>\n"+
+		"</D:activelock></D:lockdiscovery></D:prop>",
+		depth, ld.OwnerXML, timeout, escape(token), escape(ld.Root),
+	)
+}
+
+func escape(s string) string {
+	for i := 0; i < len(s); i++ {
+		switch s[i] {
+		case '"', '&', '\'', '<', '>':
+			b := bytes.NewBuffer(nil)
+			ixml.EscapeText(b, []byte(s))
+			return b.String()
+		}
+	}
+	return s
+}
+
+// Next returns the next token, if any, in the XML stream of d.
+// RFC 4918 requires to ignore comments, processing instructions
+// and directives.
+// http://www.webdav.org/specs/rfc4918.html#property_values
+// http://www.webdav.org/specs/rfc4918.html#xml-extensibility
+func next(d *ixml.Decoder) (ixml.Token, error) {
+	for {
+		t, err := d.Token()
+		if err != nil {
+			return t, err
+		}
+		switch t.(type) {
+		case ixml.Comment, ixml.Directive, ixml.ProcInst:
+			continue
+		default:
+			return t, nil
+		}
+	}
+}
+
+// http://www.webdav.org/specs/rfc4918.html#ELEMENT_prop (for propfind)
+type propfindProps []xml.Name
+
+// UnmarshalXML appends the property names enclosed within start to pn.
+//
+// It returns an error if start does not contain any properties or if
+// properties contain values. Character data between properties is ignored.
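+//
+// For example, a request body fragment such as (illustrative)
+//
+//	<D:prop xmlns:D="DAV:"><D:displayname/><D:getetag/></D:prop>
+//
+// unmarshals into the two xml.Names {DAV: displayname} and {DAV: getetag}.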
+func (pn *propfindProps) UnmarshalXML(d *ixml.Decoder, start ixml.StartElement) error { + for { + t, err := next(d) + if err != nil { + return err + } + switch t.(type) { + case ixml.EndElement: + if len(*pn) == 0 { + return fmt.Errorf("%s must not be empty", start.Name.Local) + } + return nil + case ixml.StartElement: + name := t.(ixml.StartElement).Name + t, err = next(d) + if err != nil { + return err + } + if _, ok := t.(ixml.EndElement); !ok { + return fmt.Errorf("unexpected token %T", t) + } + *pn = append(*pn, xml.Name(name)) + } + } +} + +// http://www.webdav.org/specs/rfc4918.html#ELEMENT_propfind +type propfind struct { + XMLName ixml.Name `xml:"DAV: propfind"` + Allprop *struct{} `xml:"DAV: allprop"` + Propname *struct{} `xml:"DAV: propname"` + Prop propfindProps `xml:"DAV: prop"` + Include propfindProps `xml:"DAV: include"` +} + +func readPropfind(r io.Reader) (pf propfind, status int, err error) { + c := countingReader{r: r} + if err = ixml.NewDecoder(&c).Decode(&pf); err != nil { + if err == io.EOF { + if c.n == 0 { + // An empty body means to propfind allprop. + // http://www.webdav.org/specs/rfc4918.html#METHOD_PROPFIND + return propfind{Allprop: new(struct{})}, 0, nil + } + err = errInvalidPropfind + } + return propfind{}, http.StatusBadRequest, err + } + + if pf.Allprop == nil && pf.Include != nil { + return propfind{}, http.StatusBadRequest, errInvalidPropfind + } + if pf.Allprop != nil && (pf.Prop != nil || pf.Propname != nil) { + return propfind{}, http.StatusBadRequest, errInvalidPropfind + } + if pf.Prop != nil && pf.Propname != nil { + return propfind{}, http.StatusBadRequest, errInvalidPropfind + } + if pf.Propname == nil && pf.Allprop == nil && pf.Prop == nil { + return propfind{}, http.StatusBadRequest, errInvalidPropfind + } + return pf, 0, nil +} + +// Property represents a single DAV resource property as defined in RFC 4918. +// See http://www.webdav.org/specs/rfc4918.html#data.model.for.resource.properties +type Property struct { + // XMLName is the fully qualified name that identifies this property. + XMLName xml.Name + + // Lang is an optional xml:lang attribute. + Lang string `xml:"xml:lang,attr,omitempty"` + + // InnerXML contains the XML representation of the property value. + // See http://www.webdav.org/specs/rfc4918.html#property_values + // + // Property values of complex type or mixed-content must have fully + // expanded XML namespaces or be self-contained with according + // XML namespace declarations. They must not rely on any XML + // namespace declarations within the scope of the XML document, + // even including the DAV: namespace. + InnerXML []byte `xml:",innerxml"` +} + +// ixmlProperty is the same as the Property type except it holds an ixml.Name +// instead of an xml.Name. +type ixmlProperty struct { + XMLName ixml.Name + Lang string `xml:"xml:lang,attr,omitempty"` + InnerXML []byte `xml:",innerxml"` +} + +// http://www.webdav.org/specs/rfc4918.html#ELEMENT_error +// See multistatusWriter for the "D:" namespace prefix. +type xmlError struct { + XMLName ixml.Name `xml:"D:error"` + InnerXML []byte `xml:",innerxml"` +} + +// http://www.webdav.org/specs/rfc4918.html#ELEMENT_propstat +// See multistatusWriter for the "D:" namespace prefix. 
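+//
+// Rendered into the multistatus body, a propstat looks roughly like
+// (illustrative):
+//
+//	<D:propstat>
+//		<D:prop><D:getcontentlength>42</D:getcontentlength></D:prop>
+//		<D:status>HTTP/1.1 200 OK</D:status>
+//	</D:propstat>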
+type propstat struct {
+	Prop                []Property `xml:"D:prop>_ignored_"`
+	Status              string     `xml:"D:status"`
+	Error               *xmlError  `xml:"D:error"`
+	ResponseDescription string     `xml:"D:responsedescription,omitempty"`
+}
+
+// ixmlPropstat is the same as the propstat type except it holds an ixml.Name
+// instead of an xml.Name.
+type ixmlPropstat struct {
+	Prop                []ixmlProperty `xml:"D:prop>_ignored_"`
+	Status              string         `xml:"D:status"`
+	Error               *xmlError      `xml:"D:error"`
+	ResponseDescription string         `xml:"D:responsedescription,omitempty"`
+}
+
+// MarshalXML prepends the "D:" namespace prefix on properties in the DAV: namespace
+// before encoding. See multistatusWriter.
+func (ps propstat) MarshalXML(e *ixml.Encoder, start ixml.StartElement) error {
+	// Convert from a propstat to an ixmlPropstat.
+	ixmlPs := ixmlPropstat{
+		Prop:                make([]ixmlProperty, len(ps.Prop)),
+		Status:              ps.Status,
+		Error:               ps.Error,
+		ResponseDescription: ps.ResponseDescription,
+	}
+	for k, prop := range ps.Prop {
+		ixmlPs.Prop[k] = ixmlProperty{
+			XMLName:  ixml.Name(prop.XMLName),
+			Lang:     prop.Lang,
+			InnerXML: prop.InnerXML,
+		}
+	}
+
+	for k, prop := range ixmlPs.Prop {
+		if prop.XMLName.Space == "DAV:" {
+			prop.XMLName = ixml.Name{Space: "", Local: "D:" + prop.XMLName.Local}
+			ixmlPs.Prop[k] = prop
+		}
+	}
+	// Distinct type to avoid infinite recursion of MarshalXML.
+	type newpropstat ixmlPropstat
+	return e.EncodeElement(newpropstat(ixmlPs), start)
+}
+
+// http://www.webdav.org/specs/rfc4918.html#ELEMENT_response
+// See multistatusWriter for the "D:" namespace prefix.
+type response struct {
+	XMLName             ixml.Name  `xml:"D:response"`
+	Href                []string   `xml:"D:href"`
+	Propstat            []propstat `xml:"D:propstat"`
+	Status              string     `xml:"D:status,omitempty"`
+	Error               *xmlError  `xml:"D:error"`
+	ResponseDescription string     `xml:"D:responsedescription,omitempty"`
+}
+
+// multistatusWriter marshals one or more Responses into an XML
+// multistatus response.
+// See http://www.webdav.org/specs/rfc4918.html#ELEMENT_multistatus
+// TODO(rsto, mpl): As a workaround, the "D:" namespace prefix, defined as
+// "DAV:" on this element, is prepended on the nested response, as well as on all
+// its nested elements. All property names in the DAV: namespace are prefixed as
+// well. This is because some versions of Mini-Redirector (on Windows 7) ignore
+// elements with a default namespace (no prefixed namespace). A less intrusive fix
+// should be possible after golang.org/cl/11074. See https://golang.org/issue/11177
+type multistatusWriter struct {
+	// responseDescription contains the optional responsedescription
+	// of the multistatus XML element. Only the latest content before
+	// close will be emitted. Empty response descriptions are not
+	// written.
+	responseDescription string
+
+	w   http.ResponseWriter
+	enc *ixml.Encoder
+}
+
+// write validates and emits a DAV response as part of a multistatus response
+// element.
+//
+// It sets the HTTP status code of its underlying http.ResponseWriter to 207
+// (Multi-Status) and populates the Content-Type header. If r is the
+// first valid response to be written, write prepends the XML representation
+// of r with a multistatus tag. Callers must call close after the last response
+// has been written.
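+//
+// Sketch of a minimal body produced through this writer (illustrative only;
+// element names carry the D: prefix per the workaround described above):
+//
+//	<?xml version="1.0" encoding="UTF-8"?>
+//	<D:multistatus xmlns:D="DAV:">
+//	  <D:response>
+//	    <D:href>/file.txt</D:href>
+//	    <D:propstat>
+//	      <D:prop><D:displayname>file.txt</D:displayname></D:prop>
+//	      <D:status>HTTP/1.1 200 OK</D:status>
+//	    </D:propstat>
+//	  </D:response>
+//	</D:multistatus>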
+func (w *multistatusWriter) write(r *response) error {
+	switch len(r.Href) {
+	case 0:
+		return errInvalidResponse
+	case 1:
+		if len(r.Propstat) > 0 != (r.Status == "") {
+			return errInvalidResponse
+		}
+	default:
+		if len(r.Propstat) > 0 || r.Status == "" {
+			return errInvalidResponse
+		}
+	}
+	err := w.writeHeader()
+	if err != nil {
+		return err
+	}
+	return w.enc.Encode(r)
+}
+
+// writeHeader writes an XML multistatus start element on w's underlying
+// http.ResponseWriter and returns the result of the write operation.
+// After the first write attempt, writeHeader becomes a no-op.
+func (w *multistatusWriter) writeHeader() error {
+	if w.enc != nil {
+		return nil
+	}
+	w.w.Header().Add("Content-Type", "text/xml; charset=utf-8")
+	w.w.WriteHeader(StatusMulti)
+	_, err := fmt.Fprintf(w.w, `<?xml version="1.0" encoding="UTF-8"?>`)
+	if err != nil {
+		return err
+	}
+	w.enc = ixml.NewEncoder(w.w)
+	return w.enc.EncodeToken(ixml.StartElement{
+		Name: ixml.Name{
+			Space: "DAV:",
+			Local: "multistatus",
+		},
+		Attr: []ixml.Attr{{
+			Name:  ixml.Name{Space: "xmlns", Local: "D"},
+			Value: "DAV:",
+		}},
+	})
+}
+
+// close completes the marshalling of the multistatus response. It returns
+// an error if the multistatus response could not be completed. If both the
+// return value and field enc of w are nil, then no multistatus response has
+// been written.
+func (w *multistatusWriter) close() error {
+	if w.enc == nil {
+		return nil
+	}
+	var end []ixml.Token
+	if w.responseDescription != "" {
+		name := ixml.Name{Space: "DAV:", Local: "responsedescription"}
+		end = append(end,
+			ixml.StartElement{Name: name},
+			ixml.CharData(w.responseDescription),
+			ixml.EndElement{Name: name},
+		)
+	}
+	end = append(end, ixml.EndElement{
+		Name: ixml.Name{Space: "DAV:", Local: "multistatus"},
+	})
+	for _, t := range end {
+		err := w.enc.EncodeToken(t)
+		if err != nil {
+			return err
+		}
+	}
+	return w.enc.Flush()
+}
+
+var xmlLangName = ixml.Name{Space: "http://www.w3.org/XML/1998/namespace", Local: "lang"}
+
+func xmlLang(s ixml.StartElement, d string) string {
+	for _, attr := range s.Attr {
+		if attr.Name == xmlLangName {
+			return attr.Value
+		}
+	}
+	return d
+}
+
+type xmlValue []byte
+
+func (v *xmlValue) UnmarshalXML(d *ixml.Decoder, start ixml.StartElement) error {
+	// The XML value of a property can be arbitrary, mixed-content XML.
+	// To make sure that the unmarshalled value contains all required
+	// namespaces, we encode all the property value XML tokens into a
+	// buffer. This forces the encoder to redeclare any used namespaces.
+	var b bytes.Buffer
+	e := ixml.NewEncoder(&b)
+	for {
+		t, err := next(d)
+		if err != nil {
+			return err
+		}
+		if e, ok := t.(ixml.EndElement); ok && e.Name == start.Name {
+			break
+		}
+		if err = e.EncodeToken(t); err != nil {
+			return err
+		}
+	}
+	err := e.Flush()
+	if err != nil {
+		return err
+	}
+	*v = b.Bytes()
+	return nil
+}
+
+// http://www.webdav.org/specs/rfc4918.html#ELEMENT_prop (for proppatch)
+type proppatchProps []Property
+
+// UnmarshalXML appends the property names and values enclosed within start
+// to ps.
+//
+// An xml:lang attribute that is defined either on the DAV:prop or property
+// name XML element is propagated to the property's Lang field.
+//
+// UnmarshalXML returns an error if start does not contain any properties or if
+// property values contain syntactically incorrect XML.
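+//
+// For illustration only, a PROPPATCH body whose prop elements decode into
+// ps could look like:
+//
+//	<D:propertyupdate xmlns:D="DAV:" xmlns:Z="http://ns.example.com/z/">
+//	  <D:set>
+//	    <D:prop><Z:Authors>somevalue</Z:Authors></D:prop>
+//	  </D:set>
+//	</D:propertyupdate>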
+func (ps *proppatchProps) UnmarshalXML(d *ixml.Decoder, start ixml.StartElement) error { + lang := xmlLang(start, "") + for { + t, err := next(d) + if err != nil { + return err + } + switch elem := t.(type) { + case ixml.EndElement: + if len(*ps) == 0 { + return fmt.Errorf("%s must not be empty", start.Name.Local) + } + return nil + case ixml.StartElement: + p := Property{ + XMLName: xml.Name(t.(ixml.StartElement).Name), + Lang: xmlLang(t.(ixml.StartElement), lang), + } + err = d.DecodeElement(((*xmlValue)(&p.InnerXML)), &elem) + if err != nil { + return err + } + *ps = append(*ps, p) + } + } +} + +// http://www.webdav.org/specs/rfc4918.html#ELEMENT_set +// http://www.webdav.org/specs/rfc4918.html#ELEMENT_remove +type setRemove struct { + XMLName ixml.Name + Lang string `xml:"xml:lang,attr,omitempty"` + Prop proppatchProps `xml:"DAV: prop"` +} + +// http://www.webdav.org/specs/rfc4918.html#ELEMENT_propertyupdate +type propertyupdate struct { + XMLName ixml.Name `xml:"DAV: propertyupdate"` + Lang string `xml:"xml:lang,attr,omitempty"` + SetRemove []setRemove `xml:",any"` +} + +func readProppatch(r io.Reader) (patches []Proppatch, status int, err error) { + var pu propertyupdate + if err = ixml.NewDecoder(r).Decode(&pu); err != nil { + return nil, http.StatusBadRequest, err + } + for _, op := range pu.SetRemove { + remove := false + switch op.XMLName { + case ixml.Name{Space: "DAV:", Local: "set"}: + // No-op. + case ixml.Name{Space: "DAV:", Local: "remove"}: + for _, p := range op.Prop { + if len(p.InnerXML) > 0 { + return nil, http.StatusBadRequest, errInvalidProppatch + } + } + remove = true + default: + return nil, http.StatusBadRequest, errInvalidProppatch + } + patches = append(patches, Proppatch{Remove: remove, Props: op.Prop}) + } + return patches, 0, nil +} diff --git a/server/webdav/xml_test.go b/server/webdav/xml_test.go new file mode 100644 index 0000000000000000000000000000000000000000..0b05d5d8c990b54dc550c45a969136fb3cdc7d2b --- /dev/null +++ b/server/webdav/xml_test.go @@ -0,0 +1,905 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+
+package webdav
+
+import (
+	"bytes"
+	"encoding/xml"
+	"fmt"
+	"io"
+	"net/http"
+	"net/http/httptest"
+	"reflect"
+	"sort"
+	"strings"
+	"testing"
+
+	ixml "github.com/alist-org/alist/v3/server/webdav/internal/xml"
+)
+
+func TestReadLockInfo(t *testing.T) {
+	// The "section x.y.z" test cases come from section x.y.z of the spec at
+	// http://www.webdav.org/specs/rfc4918.html
+	testCases := []struct {
+		desc       string
+		input      string
+		wantLI     lockInfo
+		wantStatus int
+	}{{
+		"bad: junk",
+		"xxx",
+		lockInfo{},
+		http.StatusBadRequest,
+	}, {
+		"bad: invalid owner XML",
+		"<?xml version=\"1.0\" encoding=\"utf-8\" ?>" +
+			"<D:lockinfo xmlns:D='DAV:'>\n" +
+			"  <D:lockscope><D:exclusive/></D:lockscope>\n" +
+			"  <D:locktype><D:write/></D:locktype>\n" +
+			"  <D:owner>\n" +
+			"    <D:href>   no end tag   \n" +
+			"  </D:owner>\n" +
+			"</D:lockinfo>",
+		lockInfo{},
+		http.StatusBadRequest,
+	}, {
+		"bad: invalid UTF-8",
+		"<?xml version=\"1.0\" encoding=\"utf-8\" ?>" +
+			"<D:lockinfo xmlns:D='DAV:'>\n" +
+			"  <D:lockscope><D:exclusive/></D:lockscope>\n" +
+			"  <D:locktype><D:write/></D:locktype>\n" +
+			"  <D:owner>\n" +
+			"    <D:href>   \xff   </D:href>\n" +
+			"  </D:owner>\n" +
+			"</D:lockinfo>",
+		lockInfo{},
+		http.StatusBadRequest,
+	}, {
+		"bad: unfinished XML #1",
+		"<?xml version=\"1.0\" encoding=\"utf-8\" ?>" +
+			"<D:lockinfo xmlns:D='DAV:'>\n" +
+			"  <D:lockscope><D:exclusive/></D:lockscope>\n" +
+			"  <D:locktype><D:write/></D:locktype>\n",
+		lockInfo{},
+		http.StatusBadRequest,
+	}, {
+		"bad: unfinished XML #2",
+		"<?xml version=\"1.0\" encoding=\"utf-8\" ?>" +
+			"<D:lockinfo xmlns:D='DAV:'>\n" +
+			"  <D:lockscope><D:exclusive/></D:lockscope>\n" +
+			"  <D:locktype><D:write/></D:locktype>\n" +
+			"  <D:owner>\n",
+		lockInfo{},
+		http.StatusBadRequest,
+	}, {
+		"good: empty",
+		"",
+		lockInfo{},
+		0,
+	}, {
+		"good: plain-text owner",
+		"<?xml version=\"1.0\" encoding=\"utf-8\" ?>" +
+			"<D:lockinfo xmlns:D='DAV:'>\n" +
+			"  <D:lockscope><D:exclusive/></D:lockscope>\n" +
+			"  <D:locktype><D:write/></D:locktype>\n" +
+			"  <D:owner>gopher</D:owner>\n" +
+			"</D:lockinfo>",
+		lockInfo{
+			XMLName:   ixml.Name{Space: "DAV:", Local: "lockinfo"},
+			Exclusive: new(struct{}),
+			Write:     new(struct{}),
+			Owner: owner{
+				InnerXML: "gopher",
+			},
+		},
+		0,
+	}, {
+		"section 9.10.7",
+		"<?xml version=\"1.0\" encoding=\"utf-8\" ?>" +
+			"<D:lockinfo xmlns:D='DAV:'>\n" +
+			"  <D:lockscope><D:exclusive/></D:lockscope>\n" +
+			"  <D:locktype><D:write/></D:locktype>\n" +
+			"  <D:owner>\n" +
+			"    <D:href>http://example.org/~ejw/contact.html</D:href>\n" +
+			"  </D:owner>\n" +
+			"</D:lockinfo>",
+		lockInfo{
+			XMLName:   ixml.Name{Space: "DAV:", Local: "lockinfo"},
+			Exclusive: new(struct{}),
+			Write:     new(struct{}),
+			Owner: owner{
+				InnerXML: "\n    <D:href>http://example.org/~ejw/contact.html</D:href>\n  ",
+			},
+		},
+		0,
+	}}
+
+	for _, tc := range testCases {
+		li, status, err := readLockInfo(strings.NewReader(tc.input))
+		if tc.wantStatus != 0 {
+			if err == nil {
+				t.Errorf("%s: got nil error, want non-nil", tc.desc)
+				continue
+			}
+		} else if err != nil {
+			t.Errorf("%s: %v", tc.desc, err)
+			continue
+		}
+		if !reflect.DeepEqual(li, tc.wantLI) || status != tc.wantStatus {
+			t.Errorf("%s:\ngot lockInfo=%v, status=%v\nwant lockInfo=%v, status=%v",
+				tc.desc, li, status, tc.wantLI, tc.wantStatus)
+			continue
+		}
+	}
+}
+
+func TestReadPropfind(t *testing.T) {
+	testCases := []struct {
+		desc       string
+		input      string
+		wantPF     propfind
+		wantStatus int
+	}{{
+		desc: "propfind: propname",
+		input: "" +
+			"<?xml version=\"1.0\" encoding=\"utf-8\"?>" +
+			"<A:propfind xmlns:A='DAV:'>\n" +
+			"  <A:propname/>\n" +
+			"</A:propfind>",
+		wantPF: propfind{
+			XMLName:  ixml.Name{Space: "DAV:", Local: "propfind"},
+			Propname: new(struct{}),
+		},
+	}, {
+		desc:  "propfind: empty body means allprop",
+		input: "",
+		wantPF: propfind{
+			Allprop: new(struct{}),
+		},
+	}, {
+		desc: "propfind: allprop",
+		input: "" +
+			"<?xml version=\"1.0\" encoding=\"utf-8\"?>" +
+			"<A:propfind xmlns:A='DAV:'>\n" +
+			"  <A:allprop/>\n" +
+			"</A:propfind>",
+		wantPF: propfind{
+			XMLName: ixml.Name{Space: "DAV:", Local: "propfind"},
+			Allprop: new(struct{}),
+		},
+	}, {
+		desc: "propfind: allprop followed by include",
+		input: "" +
+			"<?xml version=\"1.0\" encoding=\"utf-8\"?>" +
+			"<A:propfind xmlns:A='DAV:'>\n" +
+			"  <A:allprop/>\n" +
+			"  <A:include><A:displayname/></A:include>\n" +
+			"</A:propfind>",
+		wantPF: propfind{
+			XMLName: ixml.Name{Space: "DAV:", Local: "propfind"},
+			Allprop: new(struct{}),
+			Include: propfindProps{xml.Name{Space: "DAV:", Local: "displayname"}},
+		},
+	}, {
+		desc: "propfind: include followed by allprop",
+		input: "" +
+			"<?xml version=\"1.0\" encoding=\"utf-8\"?>" +
+			"<A:propfind xmlns:A='DAV:'>\n" +
+			"  <A:include><A:displayname/></A:include>\n" +
+			"  <A:allprop/>\n" +
+			"</A:propfind>",
+		wantPF: propfind{
+			XMLName: ixml.Name{Space: "DAV:", Local: "propfind"},
+			Allprop: new(struct{}),
+			Include: propfindProps{xml.Name{Space: "DAV:", Local: "displayname"}},
+		},
+	}, {
+		desc: "propfind: propfind",
+		input: "" +
+			"<?xml version=\"1.0\" encoding=\"utf-8\"?>" +
+			"<A:propfind xmlns:A='DAV:'>\n" +
+			"  <A:prop><A:displayname/></A:prop>\n" +
+			"</A:propfind>",
+		wantPF: propfind{
"DAV:", Local: "propfind"}, + Prop: propfindProps{xml.Name{Space: "DAV:", Local: "displayname"}}, + }, + }, { + desc: "propfind: prop with ignored comments", + input: "" + + "\n" + + " \n" + + " \n" + + " \n" + + " \n" + + "", + wantPF: propfind{ + XMLName: ixml.Name{Space: "DAV:", Local: "propfind"}, + Prop: propfindProps{xml.Name{Space: "DAV:", Local: "displayname"}}, + }, + }, { + desc: "propfind: propfind with ignored whitespace", + input: "" + + "\n" + + " \n" + + "", + wantPF: propfind{ + XMLName: ixml.Name{Space: "DAV:", Local: "propfind"}, + Prop: propfindProps{xml.Name{Space: "DAV:", Local: "displayname"}}, + }, + }, { + desc: "propfind: propfind with ignored mixed-content", + input: "" + + "\n" + + " foobar\n" + + "", + wantPF: propfind{ + XMLName: ixml.Name{Space: "DAV:", Local: "propfind"}, + Prop: propfindProps{xml.Name{Space: "DAV:", Local: "displayname"}}, + }, + }, { + desc: "propfind: propname with ignored element (section A.4)", + input: "" + + "\n" + + " \n" + + " *boss*\n" + + "", + wantPF: propfind{ + XMLName: ixml.Name{Space: "DAV:", Local: "propfind"}, + Propname: new(struct{}), + }, + }, { + desc: "propfind: bad: junk", + input: "xxx", + wantStatus: http.StatusBadRequest, + }, { + desc: "propfind: bad: propname and allprop (section A.3)", + input: "" + + "\n" + + " " + + " " + + "", + wantStatus: http.StatusBadRequest, + }, { + desc: "propfind: bad: propname and prop", + input: "" + + "\n" + + " \n" + + " \n" + + "", + wantStatus: http.StatusBadRequest, + }, { + desc: "propfind: bad: allprop and prop", + input: "" + + "\n" + + " \n" + + " \n" + + "", + wantStatus: http.StatusBadRequest, + }, { + desc: "propfind: bad: empty propfind with ignored element (section A.4)", + input: "" + + "\n" + + " \n" + + "", + wantStatus: http.StatusBadRequest, + }, { + desc: "propfind: bad: empty prop", + input: "" + + "\n" + + " \n" + + "", + wantStatus: http.StatusBadRequest, + }, { + desc: "propfind: bad: prop with just chardata", + input: "" + + "\n" + + " foo\n" + + "", + wantStatus: http.StatusBadRequest, + }, { + desc: "bad: interrupted prop", + input: "" + + "\n" + + " \n", + wantStatus: http.StatusBadRequest, + }, { + desc: "bad: malformed end element prop", + input: "" + + "\n" + + " \n", + wantStatus: http.StatusBadRequest, + }, { + desc: "propfind: bad: property with chardata value", + input: "" + + "\n" + + " bar\n" + + "", + wantStatus: http.StatusBadRequest, + }, { + desc: "propfind: bad: property with whitespace value", + input: "" + + "\n" + + " \n" + + "", + wantStatus: http.StatusBadRequest, + }, { + desc: "propfind: bad: include without allprop", + input: "" + + "\n" + + " \n" + + "", + wantStatus: http.StatusBadRequest, + }} + + for _, tc := range testCases { + pf, status, err := readPropfind(strings.NewReader(tc.input)) + if tc.wantStatus != 0 { + if err == nil { + t.Errorf("%s: got nil error, want non-nil", tc.desc) + continue + } + } else if err != nil { + t.Errorf("%s: %v", tc.desc, err) + continue + } + if !reflect.DeepEqual(pf, tc.wantPF) || status != tc.wantStatus { + t.Errorf("%s:\ngot propfind=%v, status=%v\nwant propfind=%v, status=%v", + tc.desc, pf, status, tc.wantPF, tc.wantStatus) + continue + } + } +} + +func TestMultistatusWriter(t *testing.T) { + ///The "section x.y.z" test cases come from section x.y.z of the spec at + // http://www.webdav.org/specs/rfc4918.html + testCases := []struct { + desc string + responses []response + respdesc string + writeHeader bool + wantXML string + wantCode int + wantErr error + }{{ + desc: "section 9.2.2 (failed 
+		desc: "section 9.2.2 (failed dependency)",
+		responses: []response{{
+			Href: []string{"http://example.com/foo"},
+			Propstat: []propstat{{
+				Prop: []Property{{
+					XMLName: xml.Name{
+						Space: "http://ns.example.com/",
+						Local: "Authors",
+					},
+				}},
+				Status: "HTTP/1.1 424 Failed Dependency",
+			}, {
+				Prop: []Property{{
+					XMLName: xml.Name{
+						Space: "http://ns.example.com/",
+						Local: "Copyright-Owner",
+					},
+				}},
+				Status: "HTTP/1.1 409 Conflict",
+			}},
+			ResponseDescription: "Copyright Owner cannot be deleted or altered.",
+		}},
+		wantXML: `` +
+			`<?xml version="1.0" encoding="UTF-8"?>` +
+			`<multistatus xmlns="DAV:">` +
+			`  <response>` +
+			`    <href>http://example.com/foo</href>` +
+			`    <propstat>` +
+			`      <prop>` +
+			`        <Authors xmlns="http://ns.example.com/"></Authors>` +
+			`      </prop>` +
+			`      <status>HTTP/1.1 424 Failed Dependency</status>` +
+			`    </propstat>` +
+			`    <propstat>` +
+			`      <prop>` +
+			`        <Copyright-Owner xmlns="http://ns.example.com/"></Copyright-Owner>` +
+			`      </prop>` +
+			`      <status>HTTP/1.1 409 Conflict</status>` +
+			`    </propstat>` +
+			`    <responsedescription>Copyright Owner cannot be deleted or altered.</responsedescription>` +
+			`  </response>` +
+			`</multistatus>`,
+		wantCode: StatusMulti,
+	}, {
+		desc: "section 9.6.2 (lock-token-submitted)",
+		responses: []response{{
+			Href:   []string{"http://example.com/foo"},
+			Status: "HTTP/1.1 423 Locked",
+			Error: &xmlError{
+				InnerXML: []byte(`<lock-token-submitted xmlns="DAV:"/>`),
+			},
+		}},
+		wantXML: `` +
+			`<?xml version="1.0" encoding="UTF-8"?>` +
+			`<multistatus xmlns="DAV:">` +
+			`  <response>` +
+			`    <href>http://example.com/foo</href>` +
+			`    <status>HTTP/1.1 423 Locked</status>` +
+			`    <error><lock-token-submitted/></error>` +
+			`  </response>` +
+			`</multistatus>`,
+		wantCode: StatusMulti,
+	}, {
+		desc: "section 9.1.3",
+		responses: []response{{
+			Href: []string{"http://example.com/foo"},
+			Propstat: []propstat{{
+				Prop: []Property{{
+					XMLName: xml.Name{Space: "http://ns.example.com/boxschema/", Local: "bigbox"},
+					InnerXML: []byte(`` +
+						`<BoxType xmlns="http://ns.example.com/boxschema/">` +
+						`Box type A` +
+						`</BoxType>`),
+				}, {
+					XMLName: xml.Name{Space: "http://ns.example.com/boxschema/", Local: "author"},
+					InnerXML: []byte(`` +
+						`<Name xmlns="http://ns.example.com/boxschema/">` +
+						`J.J. Johnson` +
+						`</Name>`),
+				}},
+				Status: "HTTP/1.1 200 OK",
+			}, {
+				Prop: []Property{{
+					XMLName: xml.Name{Space: "http://ns.example.com/boxschema/", Local: "DingALing"},
+				}, {
+					XMLName: xml.Name{Space: "http://ns.example.com/boxschema/", Local: "Random"},
+				}},
+				Status:              "HTTP/1.1 403 Forbidden",
+				ResponseDescription: "The user does not have access to the DingALing property.",
+			}},
+		}},
+		respdesc: "There has been an access violation error.",
+		wantXML: `` +
+			`<?xml version="1.0" encoding="UTF-8"?>` +
+			`<multistatus xmlns="DAV:" xmlns:B="http://ns.example.com/boxschema/">` +
+			`  <response>` +
+			`    <href>http://example.com/foo</href>` +
+			`    <propstat>` +
+			`      <prop>` +
+			`        <B:bigbox><B:BoxType>Box type A</B:BoxType></B:bigbox>` +
+			`        <B:author><B:Name>J.J. Johnson</B:Name></B:author>` +
+			`      </prop>` +
+			`      <status>HTTP/1.1 200 OK</status>` +
+			`    </propstat>` +
+			`    <propstat>` +
+			`      <prop>` +
+			`        <B:DingALing/>` +
+			`        <B:Random/>` +
+			`      </prop>` +
+			`      <status>HTTP/1.1 403 Forbidden</status>` +
+			`      <responsedescription>The user does not have access to the DingALing property.</responsedescription>` +
+			`    </propstat>` +
+			`  </response>` +
+			`  <responsedescription>There has been an access violation error.</responsedescription>` +
+			`</multistatus>`,
+		wantCode: StatusMulti,
+	}, {
+		desc: "no response written",
+		// default of http.responseWriter
+		wantCode: http.StatusOK,
+	}, {
+		desc:     "no response written (with description)",
+		respdesc: "too bad",
+		// default of http.responseWriter
+		wantCode: http.StatusOK,
+	}, {
+		desc:        "empty multistatus with header",
+		writeHeader: true,
+		wantXML:     `<multistatus xmlns="DAV:"/>`,
+		wantCode:    StatusMulti,
+	}, {
+		desc: "bad: no href",
+		responses: []response{{
+			Propstat: []propstat{{
+				Prop: []Property{{
+					XMLName: xml.Name{
+						Space: "http://example.com/",
+						Local: "foo",
+					},
+				}},
+				Status: "HTTP/1.1 200 OK",
+			}},
+		}},
+		wantErr: errInvalidResponse,
+		// default of http.responseWriter
+		wantCode: http.StatusOK,
+	}, {
+		desc: "bad: multiple hrefs and no status",
+		responses: []response{{
+			Href: []string{"http://example.com/foo", "http://example.com/bar"},
+		}},
+		wantErr: errInvalidResponse,
+		// default of http.responseWriter
+		wantCode: http.StatusOK,
+	}, {
+		desc: "bad: one href and no propstat",
+		responses: []response{{
+			Href: []string{"http://example.com/foo"},
+		}},
+		wantErr: errInvalidResponse,
+		// default of http.responseWriter
+		wantCode: http.StatusOK,
+	}, {
+		desc: "bad: status with one href and propstat",
+		responses: []response{{
+			Href: []string{"http://example.com/foo"},
+			Propstat: []propstat{{
+				Prop: []Property{{
+					XMLName: xml.Name{
+						Space: "http://example.com/",
+						Local: "foo",
+					},
+				}},
+				Status: "HTTP/1.1 200 OK",
+			}},
+			Status: "HTTP/1.1 200 OK",
+		}},
+		wantErr: errInvalidResponse,
+		// default of http.responseWriter
+		wantCode: http.StatusOK,
+	}, {
+		desc: "bad: multiple hrefs and propstat",
+		responses: []response{{
+			Href: []string{
+				"http://example.com/foo",
+				"http://example.com/bar",
+			},
+			Propstat: []propstat{{
+				Prop: []Property{{
+					XMLName: xml.Name{
+						Space: "http://example.com/",
+						Local: "foo",
+					},
+				}},
+				Status: "HTTP/1.1 200 OK",
+			}},
+		}},
+		wantErr: errInvalidResponse,
+		// default of http.responseWriter
+		wantCode: http.StatusOK,
+	}}
+
+	n := xmlNormalizer{omitWhitespace: true}
+loop:
+	for _, tc := range testCases {
+		rec := httptest.NewRecorder()
+		w := multistatusWriter{w: rec, responseDescription: tc.respdesc}
+		if tc.writeHeader {
+			if err := w.writeHeader(); err != nil {
+				t.Errorf("%s: got writeHeader error %v, want nil", tc.desc, err)
+				continue
+			}
+		}
+		for _, r := range tc.responses {
+			if err := w.write(&r); err != nil {
+				if err != tc.wantErr {
+					t.Errorf("%s: got write error %v, want %v",
+						tc.desc, err, tc.wantErr)
+				}
+				continue loop
+			}
+		}
+		if err := w.close(); err != tc.wantErr {
+			t.Errorf("%s: got close error %v, want %v",
+				tc.desc, err, tc.wantErr)
+			continue
+		}
+		if rec.Code != tc.wantCode {
+			t.Errorf("%s: got HTTP status code %d, want %d\n",
+				tc.desc, rec.Code, tc.wantCode)
+			continue
+		}
+		gotXML := rec.Body.String()
+		eq, err := n.equalXML(strings.NewReader(gotXML), strings.NewReader(tc.wantXML))
+		if err != nil {
+			t.Errorf("%s: equalXML: %v", tc.desc, err)
+			continue
+		}
+		if !eq {
+			t.Errorf("%s: XML body\ngot %s\nwant %s", tc.desc, gotXML, tc.wantXML)
+		}
+	}
+}
+
+func TestReadProppatch(t *testing.T) {
+	ppStr := func(pps []Proppatch) string {
+		var outer []string
+		for _, pp := range pps {
+			var inner []string
+			for _, p := range pp.Props {
+				inner = append(inner,
+					fmt.Sprintf("{XMLName: %q, Lang: %q, InnerXML: %q}",
+						p.XMLName, p.Lang, p.InnerXML))
+			}
+			outer = append(outer, fmt.Sprintf("{Remove: %t, Props: [%s]}",
+				pp.Remove, strings.Join(inner, ", ")))
+		}
+		return "[" + strings.Join(outer, ", ") + "]"
+	}
+
+	testCases := []struct {
+		desc       string
+		input      string
+		wantPP     []Proppatch
+		wantStatus int
+	}{{
+		desc: "proppatch: section 9.2 (with simple property value)",
+		input: `` +
+			`<?xml version="1.0" encoding="utf-8" ?>` +
+			`<D:propertyupdate xmlns:D="DAV:"` +
+			`                  xmlns:Z="http://ns.example.com/z/">` +
+			`    <D:set>` +
+			`         <D:prop><Z:Authors>somevalue</Z:Authors></D:prop>` +
+			`    </D:set>` +
+			`    <D:remove>` +
+			`         <D:prop><Z:Copyright-Owner/></D:prop>` +
+			`    </D:remove>` +
+			`</D:propertyupdate>`,
+		wantPP: []Proppatch{{
+			Props: []Property{{
+				xml.Name{Space: "http://ns.example.com/z/", Local: "Authors"},
+				"",
+				[]byte(`somevalue`),
+			}},
+		}, {
+			Remove: true,
+			Props: []Property{{
+				xml.Name{Space: "http://ns.example.com/z/", Local: "Copyright-Owner"},
+				"",
+				nil,
+			}},
+		}},
+	}, {
+		desc: "proppatch: lang attribute on prop",
+		input: `` +
+			`<?xml version="1.0" encoding="utf-8" ?>` +
+			`<D:propertyupdate xmlns:D="DAV:">` +
+			`    <D:set xml:lang="en">` +
+			`        <D:prop>` +
+			`            <foo xmlns="http://example.com/ns"/>` +
+			`        </D:prop>` +
+			`    </D:set>` +
+			`</D:propertyupdate>`,
+		wantPP: []Proppatch{{
+			Props: []Property{{
+				xml.Name{Space: "http://example.com/ns", Local: "foo"},
+				"en",
+				nil,
+			}},
+		}},
+	}, {
+		desc: "bad: remove with value",
+		input: `` +
+			`<?xml version="1.0" encoding="utf-8" ?>` +
+			`<D:propertyupdate xmlns:D="DAV:"` +
+			`                  xmlns:Z="http://ns.example.com/z/">` +
+			`    <D:remove>` +
+			`        <D:prop>` +
+			`            <Z:Authors>` +
+			`            <Z:Author>Jim Whitehead</Z:Author>` +
+			`            </Z:Authors>` +
+			`        </D:prop>` +
+			`    </D:remove>` +
+			`</D:propertyupdate>`,
+		wantStatus: http.StatusBadRequest,
+	}, {
+		desc: "bad: empty propertyupdate",
+		input: `` +
+			`<?xml version="1.0" encoding="utf-8" ?>` +
+			`<D:propertyupdate xmlns:D="DAV:"` +
+			`</D:propertyupdate>`,
+		wantStatus: http.StatusBadRequest,
+	}, {
+		desc: "bad: empty prop",
+		input: `` +
+			`<?xml version="1.0" encoding="utf-8" ?>` +
+			`<D:propertyupdate xmlns:D="DAV:"` +
+			`                  xmlns:Z="http://ns.example.com/z/">` +
+			`    <D:remove>` +
+			`        <D:prop/>` +
+			`    </D:remove>` +
+			`</D:propertyupdate>`,
+		wantStatus: http.StatusBadRequest,
+	}}
+
+	for _, tc := range testCases {
+		pp, status, err := readProppatch(strings.NewReader(tc.input))
+		if tc.wantStatus != 0 {
+			if err == nil {
+				t.Errorf("%s: got nil error, want non-nil", tc.desc)
+				continue
+			}
+		} else if err != nil {
+			t.Errorf("%s: %v", tc.desc, err)
+			continue
+		}
+		if status != tc.wantStatus {
+			t.Errorf("%s: got status %d, want %d", tc.desc, status, tc.wantStatus)
+			continue
+		}
+		if !reflect.DeepEqual(pp, tc.wantPP) || status != tc.wantStatus {
+			t.Errorf("%s: proppatch\ngot %v\nwant %v", tc.desc, ppStr(pp), ppStr(tc.wantPP))
+		}
+	}
+}
+
+func TestUnmarshalXMLValue(t *testing.T) {
+	testCases := []struct {
+		desc    string
+		input   string
+		wantVal string
+	}{{
+		desc:    "simple char data",
+		input:   "<root>foo</root>",
+		wantVal: "foo",
+	}, {
+		desc:    "empty element",
+		input:   "<root><foo/></root>",
+		wantVal: "<foo/>",
+	}, {
+		desc:    "preserve namespace",
+		input:   `<root><foo xmlns="bar"/></root>`,
+		wantVal: `<foo xmlns="bar"/>`,
+	}, {
+		desc:    "preserve root element namespace",
+		input:   `<root xmlns:bar="bar"><bar:foo/></root>`,
+		wantVal: `<foo xmlns="bar"/>`,
+	}, {
+		desc:    "preserve whitespace",
+		input:   "<root>  \t </root>",
+		wantVal: "  \t ",
+	}, {
+		desc:    "preserve mixed content",
+		input:   `<root xmlns="bar">  <foo>a<bam xmlns="baz"/> </foo> </root>`,
+		wantVal: `  <foo xmlns="bar">a<bam xmlns="baz"/> </foo> `,
+	}, {
+		desc: "section 9.2",
+		input: `` +
+			`<Z:Authors xmlns:Z="http://ns.example.com/z/">` +
+			`  <Z:Author>Jim Whitehead</Z:Author>` +
+			`  <Z:Author>Roy Fielding</Z:Author>` +
+			`</Z:Authors>`,
+		wantVal: `` +
+			`  <Author xmlns="http://ns.example.com/z/">Jim Whitehead</Author>` +
+			`  <Author xmlns="http://ns.example.com/z/">Roy Fielding</Author>`,
+	}, {
+		desc: "section 4.3.1 (mixed content)",
+		input: `` +
+			`<x:author xmlns:x='http://example.com/ns' xmlns:D="DAV:">` +
+			`  <x:name>Jane Doe</x:name>` +
+			`  <!-- Jane's contact info -->` +
+			`  <x:uri type='email' added='2005-11-26'>mailto:jane.doe@example.com</x:uri>` +
+			`  <x:uri type='web' added='2005-11-27'>http://www.example.com</x:uri>` +
+			`  <x:notes xmlns:h='http://www.w3.org/1999/xhtml'>` +
+			`    Jane has been working way <h:em>too</h:em> long on the` +
+			`    long-awaited revision of <![CDATA[<RFC2518>]]>.` +
+			`  </x:notes>` +
+			`</x:author>`,
+		wantVal: `` +
+			`  <name xmlns="http://example.com/ns">Jane Doe</name>` +
+			`  ` +
+			`  <uri type='email' added='2005-11-26' xmlns="http://example.com/ns">mailto:jane.doe@example.com</uri>` +
+			`  <uri type='web' added='2005-11-27' xmlns="http://example.com/ns">http://www.example.com</uri>` +
+			`  <notes xmlns="http://example.com/ns" xmlns:h="http://www.w3.org/1999/xhtml">` +
+			`    Jane has been working way <h:em>too</h:em> long on the` +
+			`    long-awaited revision of &lt;RFC2518&gt;.` +
+			`  </notes>`,
+	}}
+
+	var n xmlNormalizer
+	for _, tc := range testCases {
+		d := ixml.NewDecoder(strings.NewReader(tc.input))
+		var v xmlValue
+		if err := d.Decode(&v); err != nil {
+			t.Errorf("%s: got error %v, want nil", tc.desc, err)
+			continue
+		}
+		eq, err := n.equalXML(bytes.NewReader(v), strings.NewReader(tc.wantVal))
+		if err != nil {
+			t.Errorf("%s: equalXML: %v", tc.desc, err)
+			continue
+		}
+		if !eq {
+			t.Errorf("%s:\ngot %s\nwant %s", tc.desc, string(v), tc.wantVal)
+		}
+	}
+}
+
+// xmlNormalizer normalizes XML.
+type xmlNormalizer struct {
+	// omitWhitespace instructs to ignore whitespace between element tags.
+	omitWhitespace bool
+	// omitComments instructs to ignore XML comments.
+	omitComments bool
+}
+
+// normalize writes the normalized XML content of r to w. It applies the
+// following rules:
+//
+//   - Rename namespace prefixes according to an internal heuristic.
+//   - Remove unnecessary namespace declarations.
+//   - Sort attributes in XML start elements in lexical order of their
+//     fully qualified name.
+//   - Remove XML directives and processing instructions.
+//   - Remove CDATA between XML tags that only contains whitespace, if
+//     instructed to do so.
+//   - Remove comments, if instructed to do so.
+func (n *xmlNormalizer) normalize(w io.Writer, r io.Reader) error {
+	d := ixml.NewDecoder(r)
+	e := ixml.NewEncoder(w)
+	for {
+		t, err := d.Token()
+		if err != nil {
+			if t == nil && err == io.EOF {
+				break
+			}
+			return err
+		}
+		switch val := t.(type) {
+		case ixml.Directive, ixml.ProcInst:
+			continue
+		case ixml.Comment:
+			if n.omitComments {
+				continue
+			}
+		case ixml.CharData:
+			if n.omitWhitespace && len(bytes.TrimSpace(val)) == 0 {
+				continue
+			}
+		case ixml.StartElement:
+			start, _ := ixml.CopyToken(val).(ixml.StartElement)
+			attr := start.Attr[:0]
+			for _, a := range start.Attr {
+				if a.Name.Space == "xmlns" || a.Name.Local == "xmlns" {
+					continue
+				}
+				attr = append(attr, a)
+			}
+			sort.Sort(byName(attr))
+			start.Attr = attr
+			t = start
+		}
+		err = e.EncodeToken(t)
+		if err != nil {
+			return err
+		}
+	}
+	return e.Flush()
+}
+
+// equalXML tests for equality of the normalized XML contents of a and b.
+func (n *xmlNormalizer) equalXML(a, b io.Reader) (bool, error) {
+	var buf bytes.Buffer
+	if err := n.normalize(&buf, a); err != nil {
+		return false, err
+	}
+	normA := buf.String()
+	buf.Reset()
+	if err := n.normalize(&buf, b); err != nil {
+		return false, err
+	}
+	normB := buf.String()
+	return normA == normB, nil
+}
+
+type byName []ixml.Attr
+
+func (a byName) Len() int      { return len(a) }
+func (a byName) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
+func (a byName) Less(i, j int) bool {
+	if a[i].Name.Space != a[j].Name.Space {
+		return a[i].Name.Space < a[j].Name.Space
+	}
+	return a[i].Name.Local < a[j].Name.Local
+}
diff --git a/wrapper/zcc-arm64 b/wrapper/zcc-arm64
new file mode 100644
index 0000000000000000000000000000000000000000..afcd6d431db03af9926fcdad756191f38cd9a503
--- /dev/null
+++ b/wrapper/zcc-arm64
@@ -0,0 +1,2 @@
+#!/bin/sh
+zig cc -target aarch64-windows-gnu "$@"
diff --git a/wrapper/zcxx-arm64 b/wrapper/zcxx-arm64
new file mode 100644
index 0000000000000000000000000000000000000000..25c7482fbcb092c7d2c87aa681c51334e7786eb4
--- /dev/null
+++ b/wrapper/zcxx-arm64
@@ -0,0 +1,2 @@
+#!/bin/sh
+zig c++ -target aarch64-windows-gnu "$@"
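+# Assumed usage (sketch, not verified against this repo's build scripts):
+#   CC=$PWD/wrapper/zcc-arm64 CXX=$PWD/wrapper/zcxx-arm64 \
+#   CGO_ENABLED=1 GOOS=windows GOARCH=arm64 go build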