filemanager.go 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252
  1. package filemanager
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "net/url"
  8. "os"
  9. "path/filepath"
  10. "strings"
  11. "time"
  12. "cloud.google.com/go/storage"
  13. "github.com/Azure/azure-sdk-for-go/sdk/azcore"
  14. "github.com/Azure/azure-sdk-for-go/sdk/azidentity"
  15. "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/blockblob"
  16. "github.com/aws/aws-sdk-go-v2/aws"
  17. "github.com/aws/aws-sdk-go-v2/config"
  18. "github.com/aws/aws-sdk-go-v2/feature/s3/manager"
  19. "github.com/aws/aws-sdk-go-v2/service/s3"
  20. "github.com/aws/aws-sdk-go-v2/service/s3/types"
  21. )
  // Sentinel errors shared by every FileManager implementation in this package.
var (
	// ErrNotFound is returned by Download when the source file or object does
	// not exist. Callers should compare with errors.Is.
	ErrNotFound = errors.New("not found")
)

// FileManager is a unified interface for downloading and uploading a single
// file from/to various storage providers (S3, GCS, Azure Blob Storage, the
// local filesystem, or memory).
type FileManager interface {
	// Download copies the managed file's contents into f.
	Download(ctx context.Context, f *os.File) error
	// Upload copies the contents of f to the managed file.
	Upload(ctx context.Context, f *os.File) error
}
  28. // Examples of valid path:
  29. // - s3://bucket-name/path/to/file.csv
  30. // - gs://bucket-name/path/to/file.csv
  31. // - https://azblobaccount.blob.core.windows.net/containerName/path/to/file.csv
  32. // - local/file/path.csv
  33. func NewFileManager(path string) (FileManager, error) {
  34. switch {
  35. case strings.HasPrefix(path, "s3://"):
  36. return NewS3File(path)
  37. case strings.HasPrefix(path, "gs://"):
  38. return NewGCSStorageFile(path)
  39. case strings.Contains(path, "blob.core.windows.net"):
  40. return NewAzureBlobFile(path)
  41. case path == "":
  42. return nil, errors.New("empty path")
  43. default:
  44. return NewSystemFile(path), nil
  45. }
  46. }
  47. type AzureBlobFile struct {
  48. client *blockblob.Client
  49. }
  50. func NewAzureBlobFile(blobURL string) (*AzureBlobFile, error) {
  51. credential, err := azidentity.NewDefaultAzureCredential(nil)
  52. if err != nil {
  53. return nil, err
  54. }
  55. client, err := blockblob.NewClient(blobURL, credential, nil)
  56. return &AzureBlobFile{client: client}, err
  57. }
  58. func (a *AzureBlobFile) Download(ctx context.Context, f *os.File) error {
  59. _, err := a.client.DownloadFile(ctx, f, nil)
  60. // Convert Azure error into our own error.
  61. var storageErr *azcore.ResponseError
  62. if errors.As(err, &storageErr) && storageErr.ErrorCode == "BlobNotFound" {
  63. return ErrNotFound
  64. }
  65. return err
  66. }
  67. func (a *AzureBlobFile) Upload(ctx context.Context, f *os.File) error {
  68. _, err := a.client.UploadFile(ctx, f, nil)
  69. return err
  70. }
  71. type S3File struct {
  72. s3Client *s3.Client
  73. bucket string
  74. key string
  75. }
  76. func NewS3File(path string) (*S3File, error) {
  77. u, err := url.Parse(path)
  78. if err != nil {
  79. return nil, err
  80. }
  81. bucket := u.Host
  82. key := strings.TrimPrefix(u.Path, "/")
  83. if bucket == "" || key == "" {
  84. return nil, fmt.Errorf("invalid s3 path: %s", path)
  85. }
  86. cfg, err := config.LoadDefaultConfig(context.Background())
  87. if err != nil {
  88. return nil, err
  89. }
  90. return &S3File{
  91. s3Client: s3.NewFromConfig(cfg),
  92. bucket: bucket,
  93. key: key,
  94. }, nil
  95. }
  96. func (c *S3File) Download(ctx context.Context, f *os.File) error {
  97. _, err := manager.NewDownloader(c.s3Client).Download(ctx, f, &s3.GetObjectInput{
  98. Bucket: aws.String(c.bucket),
  99. Key: aws.String(c.key),
  100. })
  101. // Convert AWS error into our own error type.
  102. var notFound *types.NoSuchKey
  103. if errors.As(err, &notFound) {
  104. return ErrNotFound
  105. }
  106. return err
  107. }
  108. func (c *S3File) Upload(ctx context.Context, f *os.File) error {
  109. _, err := manager.NewUploader(c.s3Client).Upload(ctx, &s3.PutObjectInput{
  110. Bucket: aws.String(c.bucket),
  111. Key: aws.String(c.key),
  112. Body: f,
  113. })
  114. return err
  115. }
  116. type GCSStorageFile struct {
  117. bucket string
  118. key string
  119. client *storage.Client
  120. }
  121. func NewGCSStorageFile(path string) (*GCSStorageFile, error) {
  122. path = strings.TrimPrefix(path, "gs://")
  123. parts := strings.SplitN(path, "/", 2)
  124. if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
  125. return nil, errors.New("invalid GCS path")
  126. }
  127. client, err := storage.NewClient(context.TODO())
  128. if err != nil {
  129. return nil, err
  130. }
  131. return &GCSStorageFile{
  132. client: client,
  133. bucket: parts[0],
  134. key: parts[1],
  135. }, nil
  136. }
  137. func (g *GCSStorageFile) Download(ctx context.Context, f *os.File) error {
  138. r, err := g.client.Bucket(g.bucket).Object(g.key).NewReader(ctx)
  139. if err != nil {
  140. if errors.Is(err, storage.ErrObjectNotExist) {
  141. return ErrNotFound
  142. }
  143. return err
  144. }
  145. defer r.Close()
  146. _, err = io.Copy(f, r)
  147. return err
  148. }
  149. func (g *GCSStorageFile) Upload(ctx context.Context, f *os.File) error {
  150. client, err := storage.NewClient(ctx)
  151. if err != nil {
  152. return err
  153. }
  154. w := client.Bucket(g.bucket).Object(g.key).NewWriter(ctx)
  155. if _, err := io.Copy(w, f); err != nil {
  156. return err
  157. }
  158. return w.Close()
  159. }
  // SystemFile is a FileManager backed by a path on the local filesystem.
type SystemFile struct {
	// path is the local file read by Download and replaced by Upload.
	path string
}

// NewSystemFile returns a SystemFile for path. The path is stored as given.
// It is not validated or opened until Download or Upload is called.
func NewSystemFile(path string) *SystemFile {
	sf := new(SystemFile)
	sf.path = path
	return sf
}
  166. func (s *SystemFile) Download(ctx context.Context, f *os.File) error {
  167. sFile, err := os.Open(s.path)
  168. if err != nil {
  169. if os.IsNotExist(err) {
  170. return ErrNotFound
  171. }
  172. return err
  173. }
  174. defer sFile.Close()
  175. _, err = io.Copy(f, sFile)
  176. return err
  177. }
  178. func (s *SystemFile) Upload(ctx context.Context, f *os.File) error {
  179. // we want to avoid truncating the file if the upload fails
  180. // so want to write to a temp file and then rename it
  181. // to the final destination
  182. // temp file should be in the same directory as the final destination
  183. // to avoid "invalid cross-device link" errors when attempting to rename the file
  184. _, err := f.Seek(0, io.SeekStart)
  185. if err != nil {
  186. return err
  187. }
  188. tmpFilePath := filepath.Join(filepath.Dir(s.path), fmt.Sprintf(".tmp-%d", time.Now().UnixNano()))
  189. tmpF, err := os.Create(tmpFilePath)
  190. if err != nil {
  191. return err
  192. }
  193. defer os.Remove(tmpF.Name())
  194. defer tmpF.Close()
  195. _, err = io.Copy(tmpF, f)
  196. if err != nil {
  197. return err
  198. }
  199. err = os.Rename(tmpF.Name(), s.path)
  200. if err != nil {
  201. return err
  202. }
  203. return nil
  204. }
  205. type InMemoryFile struct {
  206. Data []byte
  207. }
  208. func (c *InMemoryFile) Download(ctx context.Context, f *os.File) error {
  209. if len(c.Data) == 0 {
  210. return ErrNotFound
  211. }
  212. _, err := f.Write(c.Data)
  213. return err
  214. }
  215. func (c *InMemoryFile) Upload(ctx context.Context, f *os.File) error {
  216. var err error
  217. c.Data, err = io.ReadAll(f)
  218. return err
  219. }