Commit d529dd6a authored by Bruce Flynn's avatar Bruce Flynn
Browse files

testing

parent e6dd9d8f
......@@ -22,23 +22,78 @@ var (
errBadArgs = errors.New("Invalid arguments")
)
// type yielded from binaryCommands or testCommands
type command struct {
Name string `json:"command"`
Args map[string]string `json:"args"`
}
// result of commands requested by client
type result struct {
Status string `json:"status"`
Message string `json:"message"`
Data interface{} `json:"data"`
}
type sftpClient interface {
Close() error
Create(string) (io.WriteCloser, error)
Getwd() (string, error)
Open(string) (io.ReadCloser, error)
ReadDir(string) ([]os.FileInfo, error)
Remove(string) error
}
type sftpFile struct {
Rem *sftp.File
}
func (f sftpFile) Close() error {
return f.Close()
}
func (f sftpFile) Write(buf []byte) (int, error) {
return f.Rem.Write(buf)
}
type mySFTPClient struct {
SFTP *sftp.Client
}
func (c mySFTPClient) Close() error {
return c.Close()
}
func (c mySFTPClient) Create(path string) (io.WriteCloser, error) {
f, err := c.SFTP.Create(path)
if err != nil {
return sftpFile{}, err
}
return sftpFile{f}, nil
}
func (c mySFTPClient) Getwd() (string, error) {
return c.SFTP.Getwd()
}
func (c mySFTPClient) Open(path string) (io.ReadCloser, error) {
return c.SFTP.Open(path)
}
func (c mySFTPClient) ReadDir(path string) ([]os.FileInfo, error) {
return c.ReadDir(path)
}
func (c mySFTPClient) Remove(path string) error {
return c.SFTP.Remove(path)
}
type sftpAPI struct {
url *url_.URL
client *sftp.Client
cfg ssh.ClientConfig
req io.Reader
resp io.Writer
URL *url_.URL
Client sftpClient
Cfg ssh.ClientConfig
Req io.Reader
Resp io.Writer
}
func readHeader(r io.Reader) (uint32, error) {
......@@ -70,26 +125,29 @@ type protocol interface {
type commandReader func() <-chan command
type resultWriter func(result) error
func (s sftpAPI) binaryCommands() <-chan command {
// binaryCommands commands from the source. The protocol is to first send a header containing
// the size of the command payload then the JSON serialized command in the body of the
// payload.
func binaryCommands(src io.Reader) <-chan command {
ch := make(chan command)
go func() {
for {
num, err := readHeader(s.req)
if err == io.EOF {
debug("EOF from command stream")
num, err := readHeader(src)
if err == io.EOF || err == io.ErrUnexpectedEOF {
debug("EOF reading header")
break
}
// Can't recover if the basic assumptions regarding the protocol are broken
// TODO: or can we?
if err != nil {
info("ERROR bad command, bailing: %s", err)
info("ERROR reading header, bailing: %s", err)
break
}
buf, err := readPayload(s.req, num)
buf, err := readPayload(src, num)
if err != nil {
info("ERROR bad command, bailing!!!: %s", err)
info("ERROR reading payload, bailing!!!: %s", err)
break
}
......@@ -106,11 +164,11 @@ func (s sftpAPI) binaryCommands() <-chan command {
return ch
}
func (s sftpAPI) textCommands() <-chan command {
func textCommands(src io.Reader) <-chan command {
ch := make(chan command)
go func() {
scanner := bufio.NewScanner(s.req)
scanner := bufio.NewScanner(src)
for scanner.Scan() {
line := scanner.Text()
cmd, err := decodeCommand([]byte(line))
......@@ -128,7 +186,7 @@ func (s sftpAPI) textCommands() <-chan command {
func (s sftpAPI) rempath(path string) string {
if !filepath.IsAbs(path) {
return filepath.Join(s.url.Path, path)
return filepath.Join(s.URL.Path, path)
}
return path
}
......@@ -136,7 +194,7 @@ func (s sftpAPI) rempath(path string) string {
// args: source, dest as abs paths
func (s sftpAPI) doPut(source, dest string) error {
dest = s.rempath(dest)
fout, err := s.client.Create(dest)
fout, err := s.Client.Create(dest)
if err != nil {
return errors.Wrapf(err, "can't create %s", dest)
}
......@@ -159,7 +217,7 @@ func (s sftpAPI) doPut(source, dest string) error {
func (s sftpAPI) doGet(source, dest string) error {
source = s.rempath(source)
fin, err := s.client.Open(source)
fin, err := s.Client.Open(source)
if err != nil {
return errors.Wrapf(err, "can't read %s", source)
}
......@@ -190,7 +248,7 @@ func (s sftpAPI) doListdir(path string) ([]stat, error) {
stats := []stat{}
path = s.rempath(path)
infos, err := s.client.ReadDir(path)
infos, err := s.Client.ReadDir(path)
if err != nil {
return stats, err
}
......@@ -210,11 +268,11 @@ func (s sftpAPI) doListdir(path string) ([]stat, error) {
// args: abspath
func (s sftpAPI) doDelete(path string) error {
return s.client.Remove(s.rempath(path))
return s.Client.Remove(s.rempath(path))
}
func (s *sftpAPI) ensureConnected() error {
if _, err := s.client.Getwd(); err != nil {
if _, err := s.Client.Getwd(); err != nil {
debug("not connected, connecting: %s", err)
return s.connect()
}
......@@ -273,7 +331,7 @@ func (s sftpAPI) writeBinaryResult(zult result) error {
binary.BigEndian.PutUint32(buf, uint32(len(dat)))
buf = append(buf, dat...)
_, err = s.resp.Write(buf)
_, err = s.Resp.Write(buf)
return err
}
......@@ -283,26 +341,27 @@ func (s sftpAPI) writeTextResult(zult result) error {
return err
}
dat = append(dat, '\n')
_, err = s.resp.Write(dat)
_, err = s.Resp.Write(dat)
return err
}
func (s *sftpAPI) connect() error {
debug("connecting to %s", s.url.Host)
client, err := ssh.Dial("tcp", s.url.Host, &s.cfg)
debug("connecting to %s", s.URL.Host)
client, err := ssh.Dial("tcp", s.URL.Host, &s.Cfg)
if err != nil {
return err
}
s.client, err = sftp.NewClient(client)
c, err := sftp.NewClient(client)
if err != nil {
return err
}
s.Client = mySFTPClient{c}
return nil
}
func (s sftpAPI) close() {
if s.client != nil {
s.client.Close()
if s.Client != nil {
s.Client.Close()
}
}
......@@ -384,16 +443,16 @@ func newSFTPAPI(url, privateKey, hostKey string, req io.Reader, resp io.Writer)
var err error
sftp := sftpAPI{
req: req,
resp: resp,
Req: req,
Resp: resp,
}
sftp.url, err = cleanURL(url)
sftp.URL, err = cleanURL(url)
if err != nil {
return sftp, err
}
sftp.cfg, err = newSftpConfig(sftp.url.User.Username(), privateKey, hostKey)
sftp.Cfg, err = newSftpConfig(sftp.URL.User.Username(), privateKey, hostKey)
if err != nil {
return sftp, err
}
......
package main
import (
"bytes"
"encoding/binary"
"testing"
"github.com/stretchr/testify/assert"
)
func TestBinaryCommands(t *testing.T) {
verbose = true
dat := make([]byte, 4)
binary.BigEndian.PutUint32(dat, 42)
dat = append(dat, []byte(`{"command": "PUT", "args": {"arg": "val"}}`)...)
src := bytes.NewReader(dat)
cmd := <-binaryCommands(src)
assert.Equal(t, "PUT", cmd.Name)
assert.Equal(t, map[string]string{"arg": "val"}, cmd.Args)
}
func TestTextCommands(t *testing.T) {
verbose = true
src := bytes.NewReader([]byte(`{"command": "PUT", "args": {"arg": "val"}}` + "\n"))
cmd := <-textCommands(src)
assert.Equal(t, "PUT", cmd.Name)
assert.Equal(t, map[string]string{"arg": "val"}, cmd.Args)
}
......@@ -6,5 +6,6 @@ require (
github.com/pkg/errors v0.8.1
github.com/pkg/sftp v1.10.1
github.com/spf13/pflag v1.0.5
github.com/stretchr/testify v1.4.0
golang.org/x/crypto v0.0.0-20190911031432-227b76d455e7
)
......@@ -93,14 +93,14 @@ https://gitlab.ssec.wisc.edu/brucef/sftper
}
defer sftp.close()
readCommands := sftp.binaryCommands
readCommands := binaryCommands
writeResult := sftp.writeBinaryResult
if *text {
readCommands = sftp.textCommands
readCommands = textCommands
writeResult = sftp.writeTextResult
}
for cmd := range readCommands() {
for cmd := range readCommands(sftp.Req) {
debug("command: %+v", cmd)
zult := sftp.doCommand(cmd)
debug("result %+v", zult)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment