filemanager.go 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250
  1. package filemanager
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "net/url"
  8. "os"
  9. "path/filepath"
  10. "strings"
  11. "time"
  12. "cloud.google.com/go/storage"
  13. "github.com/Azure/azure-sdk-for-go/sdk/azcore"
  14. "github.com/Azure/azure-sdk-for-go/sdk/azidentity"
  15. "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/blockblob"
  16. "github.com/aws/aws-sdk-go-v2/aws"
  17. "github.com/aws/aws-sdk-go-v2/config"
  18. "github.com/aws/aws-sdk-go-v2/feature/s3/manager"
  19. "github.com/aws/aws-sdk-go-v2/service/s3"
  20. "github.com/aws/aws-sdk-go-v2/service/s3/types"
  21. )
  22. var ErrNotFound = errors.New("not found")
  23. // FileManager is a unified interface for downloading and uploading files from various storage providers.
  24. type FileManager interface {
  25. Download(ctx context.Context, f *os.File) error
  26. Upload(ctx context.Context, f *os.File) error
  27. }
  28. // Examples of valid path:
  29. // - s3://bucket-name/path/to/file.csv
  30. // - gs://bucket-name/path/to/file.csv
  31. // - https://azblobaccount.blob.core.windows.net/containerName/path/to/file.csv
  32. // - local/file/path.csv
  33. func NewFileManager(path string) (FileManager, error) {
  34. switch {
  35. case strings.HasPrefix(path, "s3://"):
  36. return NewS3File(path)
  37. case strings.HasPrefix(path, "gs://"):
  38. return NewGCSStorageFile(path)
  39. case strings.Contains(path, "blob.core.windows.net"):
  40. return NewAzureBlobFile(path)
  41. default:
  42. return NewSystemFile(path), nil
  43. }
  44. }
  45. type AzureBlobFile struct {
  46. client *blockblob.Client
  47. }
  48. func NewAzureBlobFile(blobURL string) (*AzureBlobFile, error) {
  49. credential, err := azidentity.NewDefaultAzureCredential(nil)
  50. if err != nil {
  51. return nil, err
  52. }
  53. client, err := blockblob.NewClient(blobURL, credential, nil)
  54. return &AzureBlobFile{client: client}, err
  55. }
  56. func (a *AzureBlobFile) Download(ctx context.Context, f *os.File) error {
  57. _, err := a.client.DownloadFile(ctx, f, nil)
  58. // Convert Azure error into our own error.
  59. var storageErr *azcore.ResponseError
  60. if errors.As(err, &storageErr) && storageErr.ErrorCode == "BlobNotFound" {
  61. return ErrNotFound
  62. }
  63. return err
  64. }
  65. func (a *AzureBlobFile) Upload(ctx context.Context, f *os.File) error {
  66. _, err := a.client.UploadFile(ctx, f, nil)
  67. return err
  68. }
  69. type S3File struct {
  70. s3Client *s3.Client
  71. bucket string
  72. key string
  73. }
  74. func NewS3File(path string) (*S3File, error) {
  75. u, err := url.Parse(path)
  76. if err != nil {
  77. return nil, err
  78. }
  79. bucket := u.Host
  80. key := strings.TrimPrefix(u.Path, "/")
  81. if bucket == "" || key == "" {
  82. return nil, fmt.Errorf("invalid s3 path: %s", path)
  83. }
  84. cfg, err := config.LoadDefaultConfig(context.Background())
  85. if err != nil {
  86. return nil, err
  87. }
  88. return &S3File{
  89. s3Client: s3.NewFromConfig(cfg),
  90. bucket: bucket,
  91. key: key,
  92. }, nil
  93. }
  94. func (c *S3File) Download(ctx context.Context, f *os.File) error {
  95. _, err := manager.NewDownloader(c.s3Client).Download(ctx, f, &s3.GetObjectInput{
  96. Bucket: aws.String(c.bucket),
  97. Key: aws.String(c.key),
  98. })
  99. // Convert AWS error into our own error type.
  100. var notFound *types.NoSuchKey
  101. if errors.As(err, &notFound) {
  102. return ErrNotFound
  103. }
  104. return err
  105. }
  106. func (c *S3File) Upload(ctx context.Context, f *os.File) error {
  107. _, err := manager.NewUploader(c.s3Client).Upload(ctx, &s3.PutObjectInput{
  108. Bucket: aws.String(c.bucket),
  109. Key: aws.String(c.key),
  110. Body: f,
  111. })
  112. return err
  113. }
  114. type GCSStorageFile struct {
  115. bucket string
  116. key string
  117. client *storage.Client
  118. }
  119. func NewGCSStorageFile(path string) (*GCSStorageFile, error) {
  120. path = strings.TrimPrefix(path, "gs://")
  121. parts := strings.SplitN(path, "/", 2)
  122. if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
  123. return nil, errors.New("invalid GCS path")
  124. }
  125. client, err := storage.NewClient(context.TODO())
  126. if err != nil {
  127. return nil, err
  128. }
  129. return &GCSStorageFile{
  130. client: client,
  131. bucket: parts[0],
  132. key: parts[1],
  133. }, nil
  134. }
  135. func (g *GCSStorageFile) Download(ctx context.Context, f *os.File) error {
  136. r, err := g.client.Bucket(g.bucket).Object(g.key).NewReader(ctx)
  137. if err != nil {
  138. if errors.Is(err, storage.ErrObjectNotExist) {
  139. return ErrNotFound
  140. }
  141. return err
  142. }
  143. defer r.Close()
  144. _, err = io.Copy(f, r)
  145. return err
  146. }
  147. func (g *GCSStorageFile) Upload(ctx context.Context, f *os.File) error {
  148. client, err := storage.NewClient(ctx)
  149. if err != nil {
  150. return err
  151. }
  152. w := client.Bucket(g.bucket).Object(g.key).NewWriter(ctx)
  153. if _, err := io.Copy(w, f); err != nil {
  154. return err
  155. }
  156. return w.Close()
  157. }
  158. func NewSystemFile(path string) *SystemFile {
  159. return &SystemFile{path: path}
  160. }
  161. type SystemFile struct {
  162. path string
  163. }
  164. func (s *SystemFile) Download(ctx context.Context, f *os.File) error {
  165. sFile, err := os.Open(s.path)
  166. if err != nil {
  167. if os.IsNotExist(err) {
  168. return ErrNotFound
  169. }
  170. return err
  171. }
  172. defer sFile.Close()
  173. _, err = io.Copy(f, sFile)
  174. return err
  175. }
  176. func (s *SystemFile) Upload(ctx context.Context, f *os.File) error {
  177. // we want to avoid truncating the file if the upload fails
  178. // so want to write to a temp file and then rename it
  179. // to the final destination
  180. // temp file should be in the same directory as the final destination
  181. // to avoid "invalid cross-device link" errors when attempting to rename the file
  182. _, err := f.Seek(0, io.SeekStart)
  183. if err != nil {
  184. return err
  185. }
  186. tmpFilePath := filepath.Join(filepath.Dir(s.path), fmt.Sprintf(".tmp-%d", time.Now().UnixNano()))
  187. tmpF, err := os.Create(tmpFilePath)
  188. if err != nil {
  189. return err
  190. }
  191. defer os.Remove(tmpF.Name())
  192. defer tmpF.Close()
  193. _, err = io.Copy(tmpF, f)
  194. if err != nil {
  195. return err
  196. }
  197. err = os.Rename(tmpF.Name(), s.path)
  198. if err != nil {
  199. return err
  200. }
  201. return nil
  202. }
  203. type InMemoryFile struct {
  204. Data []byte
  205. }
  206. func (c *InMemoryFile) Download(ctx context.Context, f *os.File) error {
  207. if len(c.Data) == 0 {
  208. return ErrNotFound
  209. }
  210. _, err := f.Write(c.Data)
  211. return err
  212. }
  213. func (c *InMemoryFile) Upload(ctx context.Context, f *os.File) error {
  214. var err error
  215. c.Data, err = io.ReadAll(f)
  216. return err
  217. }