Commit 8fc87c2b authored by Bruce Flynn's avatar Bruce Flynn

fix tests and errors they revealed

parent a30225c4
......@@ -9,6 +9,8 @@ the file.
import argparse
import logging
from sqlalchemy.exc import IntegrityError
from amrc_aws import parsers, db, exc
LOG = logging
......@@ -52,7 +54,7 @@ def main():
try:
frames = parser.frames(filepath, error_callback=parse_error)
db.insert_frames(frames)
except (db.IntegrityError, exc.ParseError, ValueError) as err:
except (IntegrityError, exc.ParseError, ValueError) as err:
LOG.warning('File insert failed(%s): %s', filepath, err.__class__.__name__)
LOG.info('updating station dates')
......
......@@ -9,6 +9,8 @@ cannot be parsed will be skipped.
import argparse
import logging
from sqlalchemy.exc import IntegrityError
from amrc_aws import parsers, db
LOG = logging
......@@ -38,7 +40,7 @@ def main():
inserted = db.insert_frames([frame])
if inserted == 0:
LOG.warning('No values inserted for line %d', idx + 1)
except db.IntegrityError:
except IntegrityError:
LOG.warning('Line %d insert failed, possibly already exists', idx + 1)
......
......@@ -149,7 +149,7 @@ def _fill_dataset(stations, symbols, data, nc):
# data[0:data.shape[0]:2,2]
for symidx, symbol in enumerate(symbols):
# prefill with -999.9
arr = np.ones((len(stations), data.shape[0] / len(stations))) * -999.0
arr = np.ones((len(stations), int(data.shape[0] / len(stations)))) * -999.0
for staidx, station in enumerate(stations):
# to get the rows for a particular station
# staidx:data.shape[0]:num_stations
......@@ -166,7 +166,7 @@ def _fill_dataset(stations, symbols, data, nc):
var[:] = np.array(stations)
# set times
var = nc.variables['time']
var[:] = np.array([unixtime(d) for d in data[0:data.shape[0]:2, 0]], dtype=np.uint32)
var[:] = np.array([unixtime(d.replace(tzinfo=None)) for d in data[0:data.shape[0]:2, 0]], dtype=np.uint32)
# nc.variables['lat'][:] = [43.0, 43.0]
# nc.variables['lon'][:] = [-89.0, -89.0]
......
import os
from subprocess import call, check_call
from datetime import timedelta
from webtest import TestApp
import pytest
import sqlalchemy as sa
from pyramid import testing
from sqlalchemy.engine.url import make_url
from sqlalchemy.pool import NullPool
from amrc_aws import db
from amrc_aws.app import main
# Skip every test in this module unless TEST_DBURL points at a test database.
pytestmark = pytest.mark.skipif('TEST_DBURL' not in os.environ,
                                reason='TEST_DBURL not set')
@pytest.fixture()
def database():
    """Provision a throwaway Postgres database for a test.

    Drops and recreates the database named in TEST_DBURL, initializes the
    application schema, yields the engine, then disposes of the engine and
    resets the module-level db state so later tests start clean.
    """
    dburl = os.environ['TEST_DBURL']
    url = make_url(dburl)
    # Guard against clobbering a real database: only allow the docker
    # database or names explicitly marked as scratch ('__' prefix).
    assert url.database == 'docker' or url.database.startswith('__')
    hostopt = '-h %s' % url.host if url.host else ''
    portopt = '-p %d' % url.port if url.port else ''
    useropt = '-U %s' % url.username if url.username else ''
    opts = '{} {} {} {}'.format(hostopt, portopt, useropt, url.database)
    # dropdb/createdb's -W flag takes no value (it only forces a password
    # prompt), so a password must be supplied via the PGPASSWORD
    # environment variable, never on the command line.
    env = os.environ.copy()
    if url.password:
        env['PGPASSWORD'] = str(url.password)
    check_call("dropdb --if-exists {}".format(opts), shell=True, env=env)
    check_call("createdb {}".format(opts), shell=True, env=env)
    # NullPool closes each connection immediately so the database can be
    # dropped cleanly by the next test run.
    eng = sa.create_engine(dburl, poolclass=NullPool)
    db.init(eng)
    db.create_schema()
    yield eng
    eng.dispose()
    db.engine = None
    db.metadata = None
@pytest.fixture()
def dbdata(database):
    """Load the data.sql fixture and report its time span and stations."""
    fixture_path = os.path.join(os.path.dirname(__file__), 'data.sql')
    raw_conn = database.raw_connection()
    cursor = raw_conn.cursor()
    try:
        sql = open(fixture_path).read()
        cursor.execute(sql)
        raw_conn.commit()
    finally:
        raw_conn.close()
    span = database.execute('SELECT min(stamp), max(stamp) FROM data').fetchone()
    rows = database.execute('SELECT distinct(station) as station FROM data').fetchall()
    names = [row.station for row in rows]
    # Pad the end by one second so the max stamp is included in range queries.
    return (span.min, span.max + timedelta(seconds=1)), names
@pytest.fixture()
def wsgiapp():
    # Pyramid requires a configured testing registry before building the app.
    testing.setUp()
    app = main({})
    # Wrap the WSGI app so tests can issue HTTP requests without a server.
    yield TestApp(app)
    testing.tearDown()
This diff is collapsed.
import csv
import os
from tempfile import TemporaryDirectory
import pytest
from netCDF4 import Dataset
from amrc_aws.util import unixtime
from amrc_aws.app import symbol_names
def api_fmt(d):
    """Format a datetime the way the HTTP API expects (ISO 8601 with 'Z')."""
    stamp = d.strftime('%Y-%m-%dT%H:%M:%SZ')
    return stamp
@pytest.mark.parametrize('format', ['rdr', 'q1h'])
def test_data_nc(format, wsgiapp, dbdata):
    """GET /data.nc returns a valid NetCDF file for both product formats."""
    (start, end), stations = dbdata
    params = {'start': api_fmt(start),
              'end': api_fmt(end),
              'format': format,
              'stations': '|'.join(stations)}
    with TemporaryDirectory() as tmpdir:
        response = wsgiapp.get('/data.nc', params=params, status=200)
        # write netcdf file to disk
        filepath = os.path.join(tmpdir, 'test.nc')
        with open(filepath, 'wb') as fptr:
            fptr.write(response.body)
        # verify it looks correct
        with Dataset(filepath) as ds:
            time = ds['time'][:]
            # A masked time array would mean fill/missing values leaked
            # into the time coordinate.
            assert not hasattr(time, 'mask'), 'Time data should not have a mask'
            # Times must lie within the requested window.
            assert time[0] >= unixtime(start)
            assert time[-1] <= unixtime(end)
            # Every advertised symbol must appear as a dataset variable.
            for sym in symbol_names:
                assert sym in ds.variables
def test_data_csv(wsgiapp, dbdata):
    """GET /data.csv responds with status 200 for the fixture data."""
    (start, end), stations = dbdata
    params = {'start': api_fmt(start),
              'end': api_fmt(end),
              'stations': '|'.join(stations)}
    with TemporaryDirectory() as tmpdir:
        # The status=200 check is the real assertion here.
        response = wsgiapp.get('/data.csv', params=params, status=200)
        # write csv file to disk
        filepath = os.path.join(tmpdir, 'test.csv')
        with open(filepath, 'wb') as fptr:
            fptr.write(response.body)
import os
from datetime import datetime, timedelta
import pytest
import sqlalchemy as sa
from amrc_aws import db, config
# Skip this whole module when no database URL is configured.
# has to be assigned "pytestmark" name so entire module gets skipped
# http://doc.pytest.org/en/latest/skipping.html#skip-all-test-functions-of-a-class-or-module
pytestmark = pytest.mark.skipif(not os.getenv('AWSAPI_DB_URL'),
                                reason='required AWSAPI_DB_URL')
@pytest.fixture
def database():
    """Reset the application schema in the AWSAPI_DB_URL database.

    Drops and recreates the schema, rebuilds the application tables, and
    seeds one station_dates row so station queries have data to return.
    """
    schema = config.DB_SCHEMA
    engine = sa.create_engine(os.environ['AWSAPI_DB_URL'])
    database_name = engine.url.database
    # Safety check: refuse to run against anything but a scratch test db.
    assert database_name.startswith('__') and 'test' in database_name
    with engine.connect() as conn:
        conn.execute("DROP SCHEMA IF EXISTS {:s} CASCADE".format(schema))
        conn.execute("CREATE SCHEMA {:s}".format(schema))
        # Commit explicitly so the DDL is visible to the fresh engines
        # created below (matches the sibling fixture in the other module).
        conn.execute("COMMIT")
    engine.dispose()
    db.init(sa.create_engine(os.environ['AWSAPI_DB_URL']))
    db.create_schema()
    with db.connection() as conn:
        conn.execute("INSERT INTO station_dates(station, stamp) VALUES ('Byrd', now())")
    return sa.create_engine(os.environ['AWSAPI_DB_URL'])
def test_get_stations(database):
    """get_stations honors the optional time-window arguments."""
    # With no window, the single seeded station is returned.
    names = db.get_stations()
    assert len(names) == 1
    assert 'Byrd' in names
    # A window surrounding "now" still includes it ...
    current = datetime.utcnow()
    one_day = timedelta(days=1)
    names = db.get_stations(current - one_day, current + one_day)
    assert len(names) == 1
    # ... while a window entirely in the future excludes it.
    ten_days = timedelta(days=10)
    names = db.get_stations(current + ten_days, current + ten_days)
    assert len(names) == 0
def test_get_rdr_slice(database):
    """get_rdr_slice yields one row per interval; averaging collapses rows."""
    # Insert four 10-minute RDR rows for the first known station.
    with db.connection() as conn:
        station = db.get_station_names()[0]
        conn.execute(sa.text("""
        INSERT INTO data VALUES
        ('2016-01-01 00:00:00', :station_id, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0),
        ('2016-01-01 00:10:00', :station_id, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0),
        ('2016-01-01 00:20:00', :station_id, 2.0, 2.0, 2.0, 2.0, 0.0, 0.0),
        ('2016-01-01 00:30:00', :station_id, 3.0, 3.0, 3.0, 3.0, 0.0, 0.0)
        """), station_id=station)
    # NOTE(review): the slices below query 'STATION1' while the rows above
    # were inserted for get_station_names()[0] -- confirm these refer to
    # the same station.
    avg = None
    start = datetime(2016, 1, 1)
    end = datetime(2016, 1, 1, 0, 59, 59)
    # No averaging: one row per native 10-minute interval in the hour.
    data = db.get_rdr_slice(['STATION1'], ['air_temp'], start, end, avg)
    assert len(data) == 6
    # 600s average equals the native interval, so the count is unchanged.
    avg = 600
    data = db.get_rdr_slice(['STATION1'], ['air_temp'], start, end, avg)
    assert len(data) == 6
    # 1200s average halves the row count.
    avg = 1200
    data = db.get_rdr_slice(['STATION1'], ['air_temp'], start, end, avg)
    assert len(data) == 3
def test_get_q1h_slice(database):
    """get_q1h_slice returns one row per averaging interval in the window."""
    # Insert four q1h rows at 10-minute spacing for the first known station.
    with db.connection() as conn:
        station = db.get_station_names()[0]
        conn.execute(sa.text("""
        INSERT INTO data_q1h VALUES
        ('2016-01-01 00:00:00', :station_id, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0),
        ('2016-01-01 00:10:00', :station_id, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0),
        ('2016-01-01 00:20:00', :station_id, 2.0, 2.0, 2.0, 2.0, 0.0, 0.0),
        ('2016-01-01 00:30:00', :station_id, 3.0, 3.0, 3.0, 3.0, 0.0, 0.0)
        """), station_id=station)
    # NOTE(review): queries 'STATION1' while rows were inserted for
    # get_station_names()[0] -- confirm these refer to the same station.
    avg = 300
    start = datetime(2016, 1, 1)
    end = datetime(2016, 1, 1, 0, 59, 59)
    data = db.get_q1h_slice(['STATION1'], ['air_temp'], start, end, avg)
    # data should return 12 rows because there are 12 5min intervals in 1 hour
    assert len(data) == 12
    avg = 600
    data = db.get_q1h_slice(['STATION1'], ['air_temp'], start, end, avg)
    # 6 rows because there are 6 10min intervals in 1 hour
    assert len(data) == 6
......@@ -2,33 +2,13 @@ import os
from subprocess import check_call
import pytest
import sqlalchemy as sa
from amrc_aws import config
from amrc_aws import db
# has to be assigned "pytestmark" name so entire module gets skipped
# http://doc.pytest.org/en/latest/skipping.html#skip-all-test-functions-of-a-class-or-module
pytestmark = pytest.mark.skipif(not os.getenv('AWSAPI_DB_URL'),
reason='required AWSAPI_DB_URL')
@pytest.fixture
def database():
schema = config.DB_SCHEMA
engine = sa.create_engine(os.environ['AWSAPI_DB_URL'])
database_name = engine.url.database
assert database_name.startswith('__') and 'test' in database_name
with engine.connect() as conn:
conn.execute("DROP SCHEMA IF EXISTS {:s} CASCADE".format(schema))
conn.execute("CREATE SCHEMA {:s}".format(schema))
conn.execute("COMMIT")
engine.dispose()
db.init(sa.create_engine(os.environ['AWSAPI_DB_URL']))
db.create_schema()
return sa.create_engine(os.environ['AWSAPI_DB_URL'])
pytestmark = pytest.mark.skipif(not os.getenv('TEST_DBURL'),
reason='required TEST_DBURL')
@pytest.fixture
......@@ -49,11 +29,13 @@ Lat: 76.15S Lon: 168.40E Elev: 262m
def test_file_insert_q1h(database, q1h_file):
env = os.environ.copy()
env['AWSAPI_DB_URL'] = env['TEST_DBURL']
script = 'python -m amrc_aws.cli.file_insert -v {}'.format(q1h_file)
check_call(script, shell=True)
check_call(script, env=env, shell=True)
with db.connection() as conn:
count = conn.execute('select count(*) from data_q1h').fetchone().count
count = conn.execute('SELECT count(*) FROM data_q1h').fetchone().count
assert count == 6
......@@ -75,41 +57,52 @@ Lat : 82.32S Long : 75.99E Elev : 4027 M
def test_file_insert_rdr(database, rdr_file):
env = os.environ.copy()
env['AWSAPI_DB_URL'] = env['TEST_DBURL']
script = 'python -m amrc_aws.cli.file_insert -v {}'.format(rdr_file)
check_call(script, shell=True)
check_call(script, env=env, shell=True)
with db.connection() as conn:
count = conn.execute('select count(*) from data').fetchone().count
count = conn.execute('SELECT count(*) FROM data').fetchone().count
assert count == 6
def test_file_insert_skips_duplicates(database, rdr_file):
env = os.environ.copy()
env['AWSAPI_DB_URL'] = env['TEST_DBURL']
script = 'python -m amrc_aws.cli.file_insert -v {}'.format(rdr_file)
check_call(script, shell=True)
check_call(script, env=env, shell=True)
script = 'python -m amrc_aws.cli.file_insert -v {}'.format(rdr_file)
check_call(script, shell=True)
check_call(script, env=env, shell=True)
with db.connection() as conn:
count = conn.execute('select count(*) from data').fetchone().count
count = conn.execute('SELECT count(*) FROM data').fetchone().count
assert count == 6
def test_batch_insert(database, rdr_file, tmpdir):
env = os.environ.copy()
env['AWSAPI_DB_URL'] = env['TEST_DBURL']
script = 'python -m amrc_aws.cli.batch_insert --rdr -v -'
fptr = tmpdir.join('files.txt')
fptr.write(rdr_file + '\n')
check_call(script, shell=True, stdin=open(fptr.strpath))
check_call(script, env=env, shell=True, stdin=open(fptr.strpath))
with db.connection() as conn:
count = conn.execute('select count(*) from data').fetchone().count
count = conn.execute('SELECT count(*) FROM data').fetchone().count
assert count == 6
def test_refresh_stations(database, rdr_file, tmpdir):
env = os.environ.copy()
env['AWSAPI_DB_URL'] = env['TEST_DBURL']
script = 'python -m amrc_aws.cli.refresh_stations -v'
check_call(script, shell=True)
check_call(script, env=env, shell=True)
with db.connection() as conn:
count = conn.execute('select count(*) from station_dates').fetchone().count
count = conn.execute('SELECT count(*) FROM station_dates').fetchone().count
assert count == 0
import os
from datetime import datetime
import pytest
from amrc_aws import db
# Skip the whole module unless a test database URL is configured.
pytestmark = pytest.mark.skipif(not os.getenv('TEST_DBURL'),
                                reason='required TEST_DBURL')
def test_get_stations(dbdata):
    """get_stations returns the fixture stations for any window containing them."""
    (start, end), stations = dbdata
    # Unbounded and fixture-spanning queries both return every station.
    all_stations = db.get_stations()
    assert all_stations == stations
    in_window = db.get_stations(start, end)
    assert in_window == stations
    # A window long before the fixture data yields nothing.
    ancient = db.get_stations(datetime(1970, 1, 1), datetime(1970, 1, 2))
    assert ancient == []
@pytest.mark.parametrize(
    'format,interval', [
        ('rdr', 600),
        ('q1h', 3600)
    ])
def test_get_rdr_slice(format, interval, dbdata):
    """Slice queries return one row per interval over the fixture's day.

    Despite the name this covers both the rdr and q1h slice functions via
    parametrization; ``interval`` is the product's native resolution in
    seconds. Integer division is used so the expected row counts are ints,
    matching ``len()`` exactly (the values divide evenly).
    """
    (start, end), stations = dbdata
    get_slice = getattr(db, 'get_%s_slice' % format)
    # No averaging: one row per native interval in the 86400s day.
    avg = None
    data = get_slice([stations[0]], ['air_temp'], start, end, avg)
    assert len(data) == 86400 // interval
    # Averaging over N native intervals divides the row count by N.
    avg = interval * 2
    data = get_slice([stations[0]], ['air_temp'], start, end, avg)
    assert len(data) == 86400 // (interval * 2)
    avg = interval * 3
    data = get_slice([stations[0]], ['air_temp'], start, end, avg)
    assert len(data) == 86400 // (interval * 3)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment