From 7c1b08421eac7633e5ae249e4a1d60c2095654d7 Mon Sep 17 00:00:00 2001
From: Bruce Flynn <brucef@ssec.wisc.edu>
Date: Fri, 4 Oct 2019 14:53:30 -0500
Subject: [PATCH] testing

---
 api.go      |  24 +++++---
 api_test.go | 167 +++++++++++++++++++++++++++++++++++++++++++++++++++-
 2 files changed, 182 insertions(+), 9 deletions(-)

diff --git a/api.go b/api.go
index 63c8914..0876e46 100644
--- a/api.go
+++ b/api.go
@@ -10,6 +10,7 @@ import (
 	url_ "net/url"
 	"os"
 	"path/filepath"
+	"syscall"
 
 	"github.com/pkg/sftp"
 	"golang.org/x/crypto/ssh"
@@ -38,7 +39,7 @@ type result struct {
 type sftpClient interface {
 	Close() error
 	Create(string) (io.WriteCloser, error)
-	Getwd() (string, error)
+	Poke() bool
 	Open(string) (io.ReadCloser, error)
 	ReadDir(string) ([]os.FileInfo, error)
 	Remove(string) error
@@ -72,8 +73,9 @@ func (c mySFTPClient) Create(path string) (io.WriteCloser, error) {
 	return sftpFile{f}, nil
 }
 
-func (c mySFTPClient) Getwd() (string, error) {
-	return c.SFTP.Getwd()
+func (c mySFTPClient) Poke() bool {
+	_, err := c.SFTP.Getwd()
+	return err == nil
 }
 
 func (c mySFTPClient) Open(path string) (io.ReadCloser, error) {
@@ -184,6 +186,7 @@ func textCommands(src io.Reader) <-chan command {
 	return ch
 }
 
+// generate a path on the remote
 func (s sftpAPI) rempath(path string) string {
 	if !filepath.IsAbs(path) {
 		return filepath.Join(s.URL.Path, path)
@@ -254,12 +257,18 @@ func (s sftpAPI) doListdir(path string) ([]stat, error) {
 	}
 
 	for _, st := range infos {
-		x := st.Sys().(*sftp.FileStat)
+		var m uint32
+		switch s := st.Sys().(type) {
+		case *sftp.FileStat:
+			m = s.Mode
+		case *syscall.Stat_t:
+			m = s.Mode
+		}
 		stats = append(stats, stat{
 			st.Name(),
 			st.Size(),
 			st.ModTime().Unix(),
-			int(x.Mode),
+			int(m),
 		})
 	}
 
@@ -271,9 +280,10 @@ func (s sftpAPI) doDelete(path string) error {
 	return s.Client.Remove(s.rempath(path))
 }
 
+// make sure the client is connected by poking it
 func (s *sftpAPI) ensureConnected() error {
-	if _, err := s.Client.Getwd(); err != nil {
-		debug("not connected, connecting: %s", err)
+	if !s.Client.Poke() {
+		debug("not connected, connecting")
 		return s.connect()
 	}
 	return nil
diff --git a/api_test.go b/api_test.go
index f27a280..32b87e0 100644
--- a/api_test.go
+++ b/api_test.go
@@ -3,13 +3,19 @@ package main
 import (
 	"bytes"
 	"encoding/binary"
+	"io"
+	"io/ioutil"
+	"net/url"
+	"os"
+	"path/filepath"
+	"sort"
 	"testing"
+	"time"
 
 	"github.com/stretchr/testify/assert"
 )
 
 func TestBinaryCommands(t *testing.T) {
-	verbose = true
 
 	dat := make([]byte, 4)
 	binary.BigEndian.PutUint32(dat, 42)
@@ -23,7 +29,6 @@ func TestBinaryCommands(t *testing.T) {
 }
 
 func TestTextCommands(t *testing.T) {
-	verbose = true
 
 	src := bytes.NewReader([]byte(`{"command": "PUT", "args": {"arg": "val"}}` + "\n"))
 
@@ -32,3 +37,161 @@ func TestTextCommands(t *testing.T) {
 	assert.Equal(t, "PUT", cmd.Name)
 	assert.Equal(t, map[string]string{"arg": "val"}, cmd.Args)
 }
+
+/*
+
+  This is all fixture code for testing the API
+
+  There must be a less verbose way to do all this.
+
+*/
+type noOpCloseWriter struct {
+	w io.Writer
+}
+
+func (n noOpCloseWriter) Write(buf []byte) (int, error) {
+	return n.w.Write(buf)
+}
+
+func (n noOpCloseWriter) Close() error {
+	return nil
+}
+
+type stubClient struct {
+	putFile *bytes.Buffer
+	getFile io.ReadCloser
+	poke    bool
+	workdir string
+}
+
+func (c stubClient) Cleanup() error {
+	return os.RemoveAll(c.workdir)
+}
+
+func (c stubClient) Close() error {
+	return nil
+}
+
+func (c stubClient) Create(path string) (io.WriteCloser, error) {
+	return noOpCloseWriter{c.putFile}, nil
+}
+
+func (c stubClient) Poke() bool {
+	return c.poke
+}
+
+func (c stubClient) Open(path string) (io.ReadCloser, error) {
+	return c.getFile, nil
+}
+
+func (c stubClient) ReadDir(p string) ([]os.FileInfo, error) {
+	f, err := os.Open(c.workdir)
+	if err != nil {
+		return []os.FileInfo{}, nil
+	}
+	return f.Readdir(-1)
+}
+
+func (c stubClient) Remove(fpath string) error {
+	return nil
+}
+
+func newStubAPI() (sftpAPI, error) {
+	req, resp, err := os.Pipe()
+	if err != nil {
+		return sftpAPI{}, nil
+	}
+
+	dir, err := tempDir()
+	r := bytes.NewReader([]byte("xyz"))
+	c := stubClient{
+		putFile: new(bytes.Buffer),
+		getFile: ioutil.NopCloser(r),
+		poke:    true,
+		workdir: dir,
+	}
+	u, err := url.Parse("sftp://localhost")
+	if err != nil {
+		return sftpAPI{}, nil
+	}
+	return sftpAPI{
+		URL:    u,
+		Client: c,
+		Req:    req,
+		Resp:   resp,
+	}, nil
+}
+
+func tempDir() (string, error) {
+	tmp := os.TempDir()
+	return ioutil.TempDir(tmp, "")
+}
+
+func TestAPIdoGet(t *testing.T) {
+	api, err := newStubAPI()
+	assert.Nil(t, err)
+	defer api.Client.(stubClient).Cleanup()
+
+	dest := filepath.Join(api.Client.(stubClient).workdir, "foo.dat")
+	err = api.doGet("ignored", dest)
+	assert.Nil(t, err)
+
+	dat, err := ioutil.ReadFile(dest)
+	assert.Nil(t, err)
+	assert.Equal(t, []byte("xyz"), dat)
+}
+
+func TestAPIdoPut(t *testing.T) {
+	api, err := newStubAPI()
+	assert.Nil(t, err)
+	defer api.Client.(stubClient).Cleanup()
+
+	// create file to put
+	dest := filepath.Join(api.Client.(stubClient).workdir, "foo.dat")
+	f, err := os.Create(dest)
+	assert.Nil(t, err)
+	_, err = f.Write([]byte("zyx"))
+	assert.Nil(t, err)
+
+	err = api.doPut(dest, "ignored")
+	assert.Nil(t, err)
+
+	assert.Nil(t, err)
+	assert.Equal(t, []byte("zyx"), api.Client.(stubClient).putFile.Bytes())
+}
+
+func TestAPIdoListdir(t *testing.T) {
+	api, err := newStubAPI()
+	assert.Nil(t, err)
+	defer api.Client.(stubClient).Cleanup()
+
+	// create file/dir to put
+	dir := api.Client.(stubClient).workdir
+
+	err = os.Mkdir(filepath.Join(dir, "00_dir"), 0755)
+	assert.Nil(t, err)
+	dest := filepath.Join(dir, "01_file")
+	f, err := os.Create(dest)
+	assert.Nil(t, err)
+	_, err = f.Write([]byte("zyx"))
+	assert.Nil(t, err)
+
+	matches, err := api.doListdir("ignored")
+	assert.Nil(t, err)
+
+	assert.Equal(t, 2, len(matches))
+	sort.Slice(matches, func(i, j int) bool {
+		return matches[i].Name < matches[j].Name
+	})
+
+	assert.Equal(t, "00_dir", matches[0].Name)
+	assert.Equal(t, 16877, matches[0].Mode)
+	// make sure the mtime is within 10s
+	assert.Greater(t, time.Now().Unix()+10, matches[0].MTime)
+
+	assert.Equal(t, "01_file", matches[1].Name)
+	assert.Equal(t, 33188, int(matches[1].Mode))
+	// make sure the mtime is within 10s
+	assert.Greater(t, time.Now().Unix()+10, matches[1].MTime)
+
+}
-- 
GitLab