diff --git a/setup.py b/setup.py
index 7ff26fb697b79864a7e8acbce2a97660f3ecf201..378dc9c56507289578ae4e254ded5460bbbf9874 100644
--- a/setup.py
+++ b/setup.py
@@ -190,7 +190,7 @@ setup(
                  "Topic :: Scientific/Engineering"],
     zip_safe=False,
     include_package_data=True,
-    install_requires=['numpy', 'pillow', 'numba', 'vispy>=0.6.0,<0.7.0',
+    install_requires=['numpy', 'pillow', 'numba', 'vispy>=0.7.1',
                       'netCDF4', 'h5py', 'pyproj',
                       'pyshp', 'shapely', 'rasterio', 'sqlalchemy',
                       'appdirs', 'pyyaml', 'pyqtgraph', 'satpy', 'matplotlib',
diff --git a/uwsift/__main__.py b/uwsift/__main__.py
index ec8d6ece610b9b77e622080f509fb2b154abad80..1da9fb052a398c5ed4aef16399c2286c7193f5f9 100644
--- a/uwsift/__main__.py
+++ b/uwsift/__main__.py
@@ -402,17 +402,17 @@ class UserControlsAnimation(QtCore.QObject):
             self.ui.statusbar.showMessage("ERROR: Layer with time steps or band siblings needed", STATUS_BAR_DURATION)
         LOG.info('using siblings of {} for animation loop'.format(uuids[0] if uuids else '-unknown-'))
 
-    def toggle_animation(self, action: QtGui.QAction = None, *args):
+    def toggle_animation(self, action: QtWidgets.QAction = None, *args):
         """Toggle animation on/off."""
         new_state = self.scene_manager.layer_set.toggle_animation()
         self.ui.animPlayPause.setChecked(new_state)
 
 
-class Main(QtGui.QMainWindow):
+class Main(QtWidgets.QMainWindow):
     _last_open_dir: str = None  # directory to open files in
-    _recent_files_menu: QtGui.QMenu = None  # QMenu
-    _open_cache_dialog: QtGui.QDialog = None
-    _screenshot_dialog: QtGui.QDialog = None
+    _recent_files_menu: QtWidgets.QMenu = None  # QMenu
+    _open_cache_dialog: QtWidgets.QDialog = None
+    _screenshot_dialog: QtWidgets.QDialog = None
     _cmap_editor = None  # Gradient editor widget
     _resource_collector: ResourceSearchPathCollector = None
     _resource_collector_timer: QtCore.QTimer = None
@@ -639,7 +639,7 @@ class Main(QtGui.QMainWindow):
         gv = self.ui.timelineView
 
         # set up the widget itself
-        gv.setViewportUpdateMode(QtGui.QGraphicsView.FullViewportUpdate)
+        gv.setViewportUpdateMode(QtWidgets.QGraphicsView.FullViewportUpdate)
         gv.setHorizontalScrollBarPolicy(QtCore.Qt.ScrollBarAlwaysOn)
         gv.setVerticalScrollBarPolicy(QtCore.Qt.ScrollBarAsNeeded)
         # gv.setRenderHints(QtGui.QPainter.Antialiasing)
@@ -935,7 +935,7 @@ class Main(QtGui.QMainWindow):
             LOG.debug("Wizard closed, nothing to load")
         self._wizard_dialog = None
 
-    def remove_region_polygon(self, action: QtGui.QAction = None, *args):
+    def remove_region_polygon(self, action: QtWidgets.QAction = None, *args):
         if self.scene_manager.has_pending_polygon():
             self.scene_manager.clear_pending_polygon()
             return
@@ -945,7 +945,7 @@ class Main(QtGui.QMainWindow):
         LOG.info("Clearing polygon with name '%s'", removed_name)
         self.scene_manager.remove_polygon(removed_name)
 
-    def create_algebraic(self, action: QtGui.QAction = None, uuids=None, composite_type=CompositeType.ARITHMETIC):
+    def create_algebraic(self, action: QtWidgets.QAction = None, uuids=None, composite_type=CompositeType.ARITHMETIC):
         if uuids is None:
             uuids = list(self.layer_list_model.current_selected_uuids())
         dialog = CreateAlgebraicDialog(self.document, uuids, parent=self)
diff --git a/uwsift/view/colormap_editor.py b/uwsift/view/colormap_editor.py
index 7ee1f934d76331e9eb099c5a9b0d9286e83c3ea1..9a1bda99c65783f65dc1b7451f32dac0f6393a4f 100644
--- a/uwsift/view/colormap_editor.py
+++ b/uwsift/view/colormap_editor.py
@@ -294,7 +294,7 @@ class ColormapEditor(QtWidgets.QDialog):
 
             for cmap_name in cmap_content:
                 if cmap_name in self.builtin_colormap_states:
-                    QtGui.QMessageBox.information(
+                    QtWidgets.QMessageBox.information(
                         self, "Error", "You cannot import a colormap with "
                                        "the same name as one of the internal "
                                        "colormaps: {}".format(cmap_name))
diff --git a/uwsift/view/open_file_wizard.py b/uwsift/view/open_file_wizard.py
index 8a47d0974bf3e4e5a88f8c691400cde3e8c59d83..4b2039cd56182ffc4f0f713cebeeba010f059662 100644
--- a/uwsift/view/open_file_wizard.py
+++ b/uwsift/view/open_file_wizard.py
@@ -58,7 +58,7 @@ class OpenFileWizard(QtWidgets.QWizard):
         self.file_groups = {}
         self.unknown_files = set()
         app = QtWidgets.QApplication.instance()
-        self._unknown_icon = app.style().standardIcon(QtGui.QStyle.SP_DialogCancelButton)
+        self._unknown_icon = app.style().standardIcon(QtWidgets.QStyle.SP_DialogCancelButton)
         self._known_icon = QtGui.QIcon()
         # self._known_icon = app.style().standardIcon(QtGui.QStyle.SP_DialogApplyButton)
 
diff --git a/uwsift/view/scene_graph.py b/uwsift/view/scene_graph.py
index 46ea059db743b9661d8aa2a84f3200e49a428fcb..d29437e3fdebd51f74b35564b9d667eab2204be4 100644
--- a/uwsift/view/scene_graph.py
+++ b/uwsift/view/scene_graph.py
@@ -624,7 +624,7 @@ class SceneGraphManager(QObject):
             clim=(0., 1.),
             gamma=1.,
             interpolation='nearest',
-            method='tiled',
+            method='subdivide',
             cmap=self.document.find_colormap('grays'),
             double=False,
             texture_shape=DEFAULT_TEXTURE_SHAPE,
@@ -976,7 +976,7 @@ class SceneGraphManager(QObject):
             clim=p.climits,
             gamma=p.gamma,
             interpolation='nearest',
-            method='tiled',
+            method='subdivide',
             cmap=self.document.find_colormap(p.colormap),
             double=False,
             texture_shape=DEFAULT_TEXTURE_SHAPE,
@@ -1012,7 +1012,7 @@ class SceneGraphManager(QObject):
                 clim=p.climits,
                 gamma=p.gamma,
                 interpolation='nearest',
-                method='tiled',
+                method='subdivide',
                 cmap=None,
                 double=False,
                 texture_shape=DEFAULT_TEXTURE_SHAPE,
diff --git a/uwsift/view/test_visuals.py b/uwsift/view/test_visuals.py
new file mode 100644
index 0000000000000000000000000000000000000000..ad9be1caddb41f8f3ce925a251300f05da0e7377
--- /dev/null
+++ b/uwsift/view/test_visuals.py
@@ -0,0 +1,54 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+"""Tests for the MultiChannelImageVisual."""
+
+from uwsift.view.visuals import MultiChannelImage
+
+import numpy as np
+from vispy.testing import TestingCanvas, run_tests_if_main, requires_application
+
+
+@requires_application()
+def test_multiband_visual():
+    size = (400, 600)
+    with TestingCanvas(size=size) as c:
+        r_data = np.random.rand(*size)
+        g_data = np.random.rand(*size)
+        b_data = np.random.rand(*size)
+        image = MultiChannelImage(
+            [r_data, None, None],
+            parent=c.scene)
+
+        # Assign only R
+        result = c.render()
+        r_result = result[..., 0]
+        g_result = result[..., 1]
+        b_result = result[..., 2]
+        assert not np.allclose(r_result, 0)
+        np.testing.assert_allclose(g_result, 0)
+        np.testing.assert_allclose(b_result, 0)
+
+        # Add B
+        image.set_data([r_data, None, b_data])
+        image.clim = ("auto", "auto", "auto")
+        result = c.render()
+        r_result = result[..., 0]
+        g_result = result[..., 1]
+        b_result = result[..., 2]
+        assert not np.allclose(r_result, 0)
+        np.testing.assert_allclose(g_result, 0)
+        assert not np.allclose(b_result, 0)
+
+        # Unset R, add G
+        image.set_data([None, g_data, b_data])
+        image.clim = ("auto", "auto", "auto")
+        result = c.render()
+        r_result = result[..., 0]
+        g_result = result[..., 1]
+        b_result = result[..., 2]
+        np.testing.assert_allclose(r_result, 0)
+        assert not np.allclose(g_result, 0)
+        assert not np.allclose(b_result, 0)
+
+
+run_tests_if_main()
diff --git a/uwsift/view/texture_atlas.py b/uwsift/view/texture_atlas.py
index 206f60b834e84e99e21e446313c2e546dd1d4080..5ce8f60f91980235c50d48e2d1df3104690652ee 100644
--- a/uwsift/view/texture_atlas.py
+++ b/uwsift/view/texture_atlas.py
@@ -20,9 +20,10 @@ REQUIRES
 
 import logging
 import os
+import warnings
 
 import numpy as np
-from vispy.gloo import Texture2D
+from vispy.visuals._scalable_textures import GPUScaledTexture2D
 
 from uwsift.common import DEFAULT_TILE_HEIGHT, DEFAULT_TILE_WIDTH
 
@@ -34,26 +35,37 @@ __docformat__ = 'reStructuredText'
 LOG = logging.getLogger(__name__)
 
 
-class TextureAtlas2D(Texture2D):
+class TextureAtlas2D(GPUScaledTexture2D):
     """A 2D Texture Array structure implemented as a 2D Texture Atlas.
     """
 
-    def __init__(self, texture_shape, tile_shape=(DEFAULT_TILE_HEIGHT, DEFAULT_TILE_WIDTH),
-                 format=None, resizable=True,
-                 interpolation=None, wrapping=None,
-                 internalformat=None, resizeable=None):
-        assert len(texture_shape) == 2
+    def __init__(self, texture_shape,
+                 tile_shape=(DEFAULT_TILE_HEIGHT, DEFAULT_TILE_WIDTH),
+                 **texture_kwargs):
         # Number of tiles in each direction (y, x)
-        self.texture_shape = texture_shape
+        self.texture_shape = self._check_texture_shape(texture_shape)
         # Number of rows and columns for each tile
         self.tile_shape = tile_shape
         # Number of rows and columns to hold all of these tiles in one texture
         shape = (self.texture_shape[0] * self.tile_shape[0], self.texture_shape[1] * self.tile_shape[1])
         self.texture_size = shape
-        self._fill_array = np.tile(np.nan, self.tile_shape).astype(np.float32)
+        self._fill_array = np.tile(np.float32(np.nan), self.tile_shape)
+        # create a representative array so the texture can be initialized properly with the right dtype
+        rep_arr = np.zeros((10, 10), dtype=np.float32)
         # will add self.shape:
-        super(TextureAtlas2D, self).__init__(None, format, resizable, interpolation,
-                                             wrapping, shape, internalformat, resizeable)
+        super(TextureAtlas2D, self).__init__(data=rep_arr, **texture_kwargs)
+        # GPUScaledTexture2D always uses a "representative" size
+        # we need to force the shape to our final size so we can start setting tiles right away
+        self._resize(shape)
+
+    def _check_texture_shape(self, texture_shape):
+        if isinstance(texture_shape, tuple):
+            if len(texture_shape) != 2:
+                raise ValueError("A shape tuple must be two elements.")
+            texture_shape = texture_shape
+        else:
+            texture_shape = texture_shape.shape
+        return texture_shape
 
     def _tex_offset(self, idx):
         """Return the X, Y texture index offset for the 1D tile index.
@@ -88,4 +100,139 @@ class TextureAtlas2D(Texture2D):
             data[-5:, :] = 1000.
             data[:, :5] = 1000.
             data[:, -5:] = 1000.
-        super(TextureAtlas2D, self).set_data(data, offset=offset, copy=copy)
+        super(TextureAtlas2D, self).scale_and_set_data(data, offset=offset, copy=copy)
+
+
+class MultiChannelGPUScaledTexture2D:
+    """Wrapper class around indiviual textures.
+
+    This helper class allows for easier handling of multiple textures that
+    represent individual R, G, and B channels of an image.
+
+    """
+    _singular_texture_class = GPUScaledTexture2D
+    _ndim = 2
+
+    def __init__(self, data, **texture_kwargs):
+        # data to sent to texture when not being used
+        self._fill_arr = np.full((10, 10), np.float32(np.nan),
+                                 dtype=np.float32)
+
+        self.num_channels = len(data)
+        data = [x if x is not None else self._fill_arr for x in data]
+        self._textures = self._create_textures(self.num_channels, data,
+                                               **texture_kwargs)
+
+    def _create_textures(self, num_channels, data, **texture_kwargs):
+        return [
+            self._singular_texture_class(data[i], **texture_kwargs)
+            for i in range(num_channels)
+        ]
+
+    @property
+    def textures(self):
+        return self._textures
+
+    @property
+    def clim(self):
+        """Get color limits used when rendering the image (cmin, cmax)."""
+        return tuple(t.clim for t in self._textures)
+
+    def set_clim(self, clim):
+        if isinstance(clim, str) or len(clim) == 2:
+            clim = [clim] * self.num_channels
+
+        need_tex_upload = False
+        for tex, single_clim in zip(self._textures, clim):
+            if single_clim is None or single_clim[0] is None:
+                single_clim = (0, 0)  # let VisPy decide what to do with unusable clims
+            if tex.set_clim(single_clim):
+                need_tex_upload = True
+        return need_tex_upload
+
+    @property
+    def clim_normalized(self):
+        return tuple(tex.clim_normalized for tex in self._textures)
+
+    @property
+    def internalformat(self):
+        return self._textures[0].internalformat
+
+    @internalformat.setter
+    def internalformat(self, value):
+        for tex in self._textures:
+            tex.internalformat = value
+
+    @property
+    def interpolation(self):
+        return self._textures[0].interpolation
+
+    @interpolation.setter
+    def interpolation(self, value):
+        for tex in self._textures:
+            self._texture.interpolation = value
+
+    def check_data_format(self, data_arrays):
+        if len(data_arrays) != self.num_channels:
+            raise ValueError(f"Expected {self.num_channels} number of channels, got {len(data_arrays)}.")
+        for tex, data in zip(self._textures, data_arrays):
+            if data is not None:
+                tex.check_data_format(data)
+
+    def scale_and_set_data(self, data, offset=None, copy=False):
+        """Scale and set data for one or all sub-textures.
+
+        Parameters
+        ----------
+        data : list | ndarray
+            Texture data in the form of a numpy array or as a list of numpy
+            arrays. If a list is provided then it must be the same length as
+            ``num_channels`` for this texture. If a numpy array is provided
+            then ``offset`` should also be provided with the first value
+            representing which sub-texture to update. For example,
+            ``offset=(1, 0, 0)`` would update the entire the second (index 1)
+            sub-texture with an offset of ``(0, 0)``. The list can also contain
+            ``None`` to not update the sub-texture at that index.
+        offset: tuple | None
+            Offset into the texture where to write the provided data. If
+            ``None`` then data will be written with no offset (0). If
+            provided as a 2-element tuple then that offset will be used
+            for all sub-textures. If a 3-element tuple then the first offset
+            index represents the sub-texture to update.
+
+        """
+        is_multi = isinstance(data, (list, tuple))
+        index_provided = offset is not None and len(offset) == self._ndim + 1
+        if not is_multi and not index_provided:
+            raise ValueError("Setting texture data for a single sub-texture "
+                             "requires 'offset' to be passed with the first "
+                             "element specifying the sub-texture index.")
+        elif is_multi and index_provided:
+            warnings.warn("Multiple texture arrays were passed, but so was "
+                          "sub-texture index in 'offset'. Ignoring that index.", UserWarning)
+            offset = offset[1:]
+        if is_multi and len(data) != self.num_channels:
+            raise ValueError("Multiple provided arrays must match number of channels. "
+                             f"Got {len(data)}, expected {self.num_channels}.")
+
+        if offset is not None and len(offset) == self._ndim + 1:
+            tex_indexes = offset[:1]
+            offset = offset[1:]
+            data = [data]
+        else:
+            tex_indexes = range(self.num_channels)
+
+        for tex_idx, _data in zip(tex_indexes, data):
+            if _data is None:
+                _data = self._fill_arr
+            self._textures[tex_idx].scale_and_set_data(_data, offset=offset, copy=copy)
+
+
+class MultiChannelTextureAtlas2D(MultiChannelGPUScaledTexture2D):
+    """Helper texture for working with RGB images in SIFT."""
+
+    _singular_texture_class = TextureAtlas2D
+
+    def set_tile_data(self, tile_idx, data_arrays, copy=False):
+        for idx, data in enumerate(data_arrays):
+            self._textures[idx].set_tile_data(tile_idx, data, copy=copy)
diff --git a/uwsift/view/visuals.py b/uwsift/view/visuals.py
index 713037bf1378446bcfbf7b651fd3677c080ebd6d..d40b94a5efa6df738d1f8d54619b49c58bc58856 100644
--- a/uwsift/view/visuals.py
+++ b/uwsift/view/visuals.py
@@ -27,14 +27,11 @@ from datetime import datetime
 
 import numpy as np
 import shapefile
-from vispy.ext.six import string_types
-from vispy.gloo import VertexBuffer
-from vispy.io.datasets import load_spatial_filters
 from vispy.scene.visuals import create_visual_node
 from vispy.visuals import LineVisual, ImageVisual, IsocurveVisual
 # The below imports are needed because we subclassed the ImageVisual
-from vispy.visuals.shaders import Function
-from vispy.visuals.transforms import NullTransform
+from vispy.visuals.shaders import Function, FunctionChain
+from vispy.gloo.texture import should_cast_to_f32
 
 
 from uwsift.common import (
@@ -46,7 +43,7 @@ from uwsift.common import (
     TESS_LEVEL,
     Box, Point, Resolution, ViewBox,
 )
-from uwsift.view.texture_atlas import TextureAtlas2D, Texture2D
+from uwsift.view.texture_atlas import TextureAtlas2D, MultiChannelTextureAtlas2D, MultiChannelGPUScaledTexture2D
 from uwsift.view.tile_calculator import TileCalculator, calc_pixel_size, get_reference_points
 
 __author__ = 'rayg'
@@ -68,96 +65,6 @@ class ArrayProxy(object):
         self.shape = shape
 
 
-VERT_SHADER = """
-uniform int method;  // 0=subdivide, 1=impostor
-attribute vec2 a_position;
-attribute vec2 a_texcoord;
-varying vec2 v_texcoord;
-
-void main() {
-    v_texcoord = a_texcoord;
-    gl_Position = $transform(vec4(a_position, 0., 1.));
-}
-"""
-
-FRAG_SHADER = """
-uniform vec2 image_size;
-uniform int method;  // 0=subdivide, 1=impostor
-uniform sampler2D u_texture;
-varying vec2 v_texcoord;
-
-vec4 map_local_to_tex(vec4 x) {
-    // Cast ray from 3D viewport to surface of image
-    // (if $transform does not affect z values, then this
-    // can be optimized as simply $transform.map(x) )
-    vec4 p1 = $transform(x);
-    vec4 p2 = $transform(x + vec4(0, 0, 0.5, 0));
-    p1 /= p1.w;
-    p2 /= p2.w;
-    vec4 d = p2 - p1;
-    float f = p2.z / d.z;
-    vec4 p3 = p2 - d * f;
-
-    // finally map local to texture coords
-    return vec4(p3.xy / image_size, 0, 1);
-}
-
-
-void main()
-{
-    vec2 texcoord;
-    if( method == 0 ) {
-        texcoord = v_texcoord;
-    }
-    else {
-        // vertex shader ouptuts clip coordinates;
-        // fragment shader maps to texture coordinates
-        texcoord = map_local_to_tex(vec4(v_texcoord, 0, 1)).xy;
-    }
-
-    gl_FragColor = $color_transform($get_data(texcoord));
-}
-"""  # noqa
-
-_null_color_transform = 'vec4 pass(vec4 color) { return color; }'
-_c2l = 'float cmap(vec4 color) { return (color.r + color.g + color.b) / 3.; }'
-
-_interpolation_template = """
-    #include "misc/spatial-filters.frag"
-    vec4 texture_lookup_filtered(vec2 texcoord) {
-        if(texcoord.x < 0.0 || texcoord.x > 1.0 ||
-        texcoord.y < 0.0 || texcoord.y > 1.0) {
-            discard;
-        }
-        return %s($texture, $shape, texcoord);
-    }"""
-
-_texture_lookup = """
-    vec4 texture_lookup(vec2 texcoord) {
-        if(texcoord.x < 0.0 || texcoord.x > 1.0 ||
-        texcoord.y < 0.0 || texcoord.y > 1.0) {
-            discard;
-        }
-        vec4 val = texture2D($texture, texcoord);
-        // http://stackoverflow.com/questions/11810158/how-to-deal-with-nan-or-inf-in-opengl-es-2-0-shaders
-        if (!(val.r <= 0.0 || 0.0 <= val.r)) {
-            discard;
-        }
-
-        if ($vmin < $vmax) {
-            val.r = clamp(val.r, $vmin, $vmax);
-        } else {
-            val.r = clamp(val.r, $vmax, $vmin);
-        }
-        val.r = (val.r-$vmin)/($vmax-$vmin);
-        val.r = pow(val.r, $gamma);
-        val.g = val.r;
-        val.b = val.r;
-
-        return val;
-    }"""
-
-
 class TextureTileState(object):
     """Object to hold the state of the current tile texture.
 
@@ -243,22 +150,57 @@ class TextureTileState(object):
         return ttile_idx
 
 
-class TiledGeolocatedImageVisual(ImageVisual):
-    def __init__(self, data, origin_x, origin_y, cell_width, cell_height,
-                 shape=None,
+class SIFTTiledGeolocatedMixin:
+    def __init__(self, data, *area_params,
                  tile_shape=(DEFAULT_TILE_HEIGHT, DEFAULT_TILE_WIDTH),
                  texture_shape=(DEFAULT_TEXTURE_HEIGHT, DEFAULT_TEXTURE_WIDTH),
                  wrap_lon=False, projection=DEFAULT_PROJECTION,
-                 cmap='viridis', method='tiled', clim='auto', gamma=1.,
-                 interpolation='nearest', **kwargs):
-        if method != 'tiled':
-            raise ValueError("Only 'tiled' method is currently supported")
-        method = 'subdivide'
-        grid = (1, 1)
+                 **visual_kwargs):
+        origin_x, origin_y, cell_width, cell_height = area_params
+        if visual_kwargs.get("method", "subdivide") != "subdivide":
+            raise ValueError("Only 'subdivide' drawing method is supported.")
+        visual_kwargs["method"] = "subdivide"
+        if "grid" in visual_kwargs:
+            raise ValueError("The 'grid' keyword argument is not supported with the tiled mixin.")
 
         # visual nodes already have names, so be careful
         if not hasattr(self, "name"):
-            self.name = kwargs.get("name", None)
+            self.name = visual_kwargs.pop("name", None)
+
+        self._init_geo_parameters(
+            origin_x,
+            origin_y,
+            cell_width,
+            cell_height,
+            projection,
+            texture_shape,
+            tile_shape,
+            wrap_lon,
+            visual_kwargs.get('shape'),
+            data,
+        )
+
+        # Call the init of the Visual
+        super().__init__(data, **visual_kwargs)
+
+        self.unfreeze()
+        self.overview_info = None
+        self.init_overview(data)
+        self.freeze()
+
+    def _init_geo_parameters(
+            self,
+            origin_x,
+            origin_y,
+            cell_width,
+            cell_height,
+            projection,
+            texture_shape,
+            tile_shape,
+            wrap_lon,
+            shape,
+            data
+    ):
         self._viewable_mesh_mask = None
         self._ref1 = None
         self._ref2 = None
@@ -292,138 +234,18 @@ class TiledGeolocatedImageVisual(ImageVisual):
         # What tiles have we used and can we use
         self.texture_state = TextureTileState(self.num_tex_tiles)
 
-        # load 'float packed rgba8' interpolation kernel
-        # to load float interpolation kernel use
-        # `load_spatial_filters(packed=False)`
-        kernel, self._interpolation_names = load_spatial_filters()
-
-        self._kerneltex = Texture2D(kernel, interpolation='nearest')
-        # The unpacking can be debugged by changing "spatial-filters.frag"
-        # to have the "unpack" function just return the .r component. That
-        # combined with using the below as the _kerneltex allows debugging
-        # of the pipeline
-        # self._kerneltex = Texture2D(kernel, interpolation='linear',
-        #                             internalformat='r32f')
-
-        # create interpolation shader functions for available
-        # interpolations
-        fun = [Function(_interpolation_template % n)
-               for n in self._interpolation_names]
-        self._interpolation_names = [n.lower()
-                                     for n in self._interpolation_names]
-
-        self._interpolation_fun = dict(zip(self._interpolation_names, fun))
-        self._interpolation_names.sort()
-        self._interpolation_names = tuple(self._interpolation_names)
-
-        # overwrite "nearest" and "bilinear" spatial-filters
-        # with  "hardware" interpolation _data_lookup_fn
-        self._interpolation_fun['nearest'] = Function(_texture_lookup)
-        self._interpolation_fun['bilinear'] = Function(_texture_lookup)
-
-        if interpolation not in self._interpolation_names:
-            raise ValueError("interpolation must be one of %s" %
-                             ', '.join(self._interpolation_names))
-
-        self._interpolation = interpolation
-
-        # check texture interpolation
-        if self._interpolation == 'bilinear':
-            texture_interpolation = 'linear'
-        else:
-            texture_interpolation = 'nearest'
-
-        self._method = method
-        self._grid = grid
-        self._need_texture_upload = True
-        self._need_vertex_update = True
-        self._need_colortransform_update = True
-        self._need_interpolation_update = True
-        self._texture = TextureAtlas2D(self.texture_shape, tile_shape=self.tile_shape,
-                                       interpolation=texture_interpolation,
-                                       format="LUMINANCE", internalformat="R32F",
-                                       )
-        self._subdiv_position = VertexBuffer()
-        self._subdiv_texcoord = VertexBuffer()
-
-        # impostor quad covers entire viewport
-        vertices = np.array([[-1, -1], [1, -1], [1, 1],
-                             [-1, -1], [1, 1], [-1, 1]],
-                            dtype=np.float32)
-        self._impostor_coords = VertexBuffer(vertices)
-        self._null_tr = NullTransform()
-
-        self._init_view(self)
-        super(ImageVisual, self).__init__(vcode=VERT_SHADER, fcode=FRAG_SHADER)
-        self.set_gl_state('translucent', cull_face=False)
-        self._draw_mode = 'triangles'
-
-        # define _data_lookup_fn as None, will be setup in
-        # self._build_interpolation()
-        self._data_lookup_fn = None
-
-        self.gamma = gamma
-        self.clim = clim if clim != 'auto' else (np.nanmin(data), np.nanmax(data))
-        self._texture_LUT = None
-        self.cmap = cmap
-
-        self.overview_info = None
-        self.init_overview(data)
-        # self.transform = PROJ4Transform(projection, inverse=True)
-
-        self.freeze()
-
-    @property
-    def gamma(self):
-        return self._gamma
-
-    @gamma.setter
-    def gamma(self, gamma):
-        self._gamma = gamma if gamma is not None else 1.
-        self._need_texture_upload = True
-        self.update()
-
-    # @property
-    # def clim(self):
-    #     return (self._clim if isinstance(self._clim, string_types) else
-    #             tuple(self._clim))
-    #
-    # @clim.setter
-    # def clim(self, clim):
-    #     if isinstance(clim, string_types):
-    #         if clim != 'auto':
-    #             raise ValueError('clim must be "auto" if a string')
-    #     else:
-    #         clim = np.array(clim, float)
-    #         if clim.shape != (2,):
-    #             raise ValueError('clim must have two elements')
-    #     self._clim = clim
-    #     # FIXME: Is this supposed to be assigned to something?:
-    #     self._data_lookup_fn
-    #     self._need_clim_update = True
-    #     self.update()
-
-    @property
-    def size(self):
-        # Added to shader program, but not used by subdivide/tiled method
-        return self.shape[-2:][::-1]
-
     def init_overview(self, data):
         """Create and add a low resolution version of the data that is always
         shown behind the higher resolution image tiles.
         """
-        # FUTURE: Actually use this data attribute. For now let the base
-        #         think there is data (not None)
-        self._data = ArrayProxy(self.ndim, self.shape)
         self.overview_info = nfo = {}
         y_slice, x_slice = self.calc.overview_stride
-        nfo["data"] = data[y_slice, x_slice]
         # Update kwargs to reflect the new spatial resolution of the overview image
         nfo["cell_width"] = self.cell_width * x_slice.step
         nfo["cell_height"] = self.cell_height * y_slice.step
         # Tell the texture state that we are adding a tile that should never expire and should always exist
         nfo["texture_tile_index"] = ttile_idx = self.texture_state.add_tile((0, 0, 0), expires=False)
-        self._texture.set_tile_data(ttile_idx, self._normalize_data(nfo["data"]))
+        self._init_overview_data(ttile_idx, data)
 
         # Handle wrapping around the anti-meridian so there is a -180/180 continuous image
         num_tiles = 1 if not self.wrap_lon else 2
@@ -439,10 +261,13 @@ class TiledGeolocatedImageVisual(ImageVisual):
                                                                                    tessellation_level=TESS_LEVEL)
         self._set_vertex_tiles(nfo["vertex_coordinates"], nfo["texture_coordinates"])
 
+    def _init_overview_data(self, ttile_idx, data):
+        _y_slice, _x_slice = self.calc.calc_overview_stride(image_shape=Point(data.shape[0], data.shape[1]))
+        self._texture.set_tile_data(ttile_idx, self._normalize_data(data[_y_slice, _x_slice]))
+
     def _normalize_data(self, data):
         if data is not None and data.dtype == np.float64:
             data = data.astype(np.float32)
-
         return data
 
     def _build_texture_tiles(self, data, stride, tile_box: Box):
@@ -466,16 +291,19 @@ class TiledGeolocatedImageVisual(ImageVisual):
 
                 # Assume we were given a total image worth of this stride
                 y_slice, x_slice = self.calc.calc_tile_slice(tiy, tix, stride)
-                # force a copy of the data from the content array (provided by the workspace)
-                # to a vispy-compatible contiguous float array
-                # this can be a potentially time-expensive operation since content array is
-                # often huge and always memory-mapped, so paging may occur
-                # we don't want this paging deferred until we're back in the GUI thread pushing data to OpenGL!
-                tile_data = np.array(data[y_slice, x_slice], dtype=np.float32)
+                tile_data = self._slice_texture_tile(data, y_slice, x_slice)
                 tiles_info.append((stride, tiy, tix, tex_tile_idx, tile_data))
 
         return tiles_info
 
+    def _slice_texture_tile(self, data, y_slice, x_slice):
+        # force a copy of the data from the content array (provided by the workspace)
+        # to a vispy-compatible contiguous float array
+        # this can be a potentially time-expensive operation since content array is
+        # often huge and always memory-mapped, so paging may occur
+        # we don't want this paging deferred until we're back in the GUI thread pushing data to OpenGL!
+        return np.array(data[y_slice, x_slice], dtype=np.float32)
+
     def _set_texture_tiles(self, tiles_info):
         for tile_info in tiles_info:
             stride, tiy, tix, tex_tile_idx, data = tile_info
@@ -647,6 +475,24 @@ class TiledGeolocatedImageVisual(ImageVisual):
         self._stride = preferred_stride
         self._latest_tile_box = tile_box
 
+
+class TiledGeolocatedImageVisual(SIFTTiledGeolocatedMixin, ImageVisual):
+    def __init__(self, data, origin_x, origin_y, cell_width, cell_height,
+                 **image_kwargs):
+        super().__init__(data, origin_x, origin_y, cell_width, cell_height, **image_kwargs)
+
+    def _init_texture(self, data, texture_format):
+        if self._interpolation == 'bilinear':
+            texture_interpolation = 'linear'
+        else:
+            texture_interpolation = 'nearest'
+
+        tex = TextureAtlas2D(self.texture_shape, tile_shape=self.tile_shape,
+                             interpolation=texture_interpolation,
+                             format="LUMINANCE", internalformat="R32F",
+                             )
+        return tex
+
     def set_data(self, image):
         """Set the data
 
@@ -655,12 +501,14 @@ class TiledGeolocatedImageVisual(ImageVisual):
         image : array-like
             The image data.
         """
-        raise NotImplementedError("This image subclass does not support the 'set_data' method")
+        if self._data is not None:
+            raise NotImplementedError("This image subclass does not support the 'set_data' method.")
+        # only do this on __init__
+        super().set_data(image)
 
     def _build_texture(self):
         # _build_texture should not be used in this class, use the 2-step
         # process of '_build_texture_tiles' and '_set_texture_tiles'
-        self._set_clim_vars()
         self._need_texture_upload = False
 
     def _build_vertex_data(self):
@@ -668,12 +516,6 @@ class TiledGeolocatedImageVisual(ImageVisual):
         # process of '_build_vertex_tiles' and '_set_vertex_tiles'
         return
 
-    def _set_clim_vars(self):
-        self._data_lookup_fn["vmin"] = self._clim[0]
-        self._data_lookup_fn["vmax"] = self._clim[1]
-        self._data_lookup_fn["gamma"] = self._gamma
-        # self._need_texture_upload = True
-
 
 TiledGeolocatedImage = create_visual_node(TiledGeolocatedImageVisual)
 
@@ -683,191 +525,112 @@ _rgb_texture_lookup = """
         texcoord.y < 0.0 || texcoord.y > 1.0) {
             discard;
         }
-        vec4 val = texture2D($texture, texcoord);
-        // http://stackoverflow.com/questions/11810158/how-to-deal-with-nan-or-inf-in-opengl-es-2-0-shaders
-        if (!(val.r <= 0.0 || 0.0 <= val.r)) {
-            val.r = 0;
-            val.g = 0;
-            val.b = 0;
-            val.a = 0;
-            return val;
-        }
-
-        if ($vmin < $vmax) {
-            val.r = clamp(val.r, $vmin, $vmax);
-        } else {
-            val.r = clamp(val.r, $vmax, $vmin);
-        }
-        val.r = (val.r-$vmin)/($vmax-$vmin);
-        val.r = pow(val.r, $gamma);
-        val.g = val.r;
-        val.b = val.r;
-
+        vec4 val;
+        val.r = texture2D($texture_r, texcoord).r;
+        val.g = texture2D($texture_g, texcoord).r;
+        val.b = texture2D($texture_b, texcoord).r;
+        val.a = 1.0;
         return val;
     }"""
 
+_apply_clim = """
+    vec4 apply_clim(vec4 color) {
+        // If all the pixels are NaN make it completely transparent
+        // http://stackoverflow.com/questions/11810158/how-to-deal-with-nan-or-inf-in-opengl-es-2-0-shaders
+        if (
+            !(color.r <= 0.0 || 0.0 <= color.r) &&
+            !(color.g <= 0.0 || 0.0 <= color.g) &&
+            !(color.b <= 0.0 || 0.0 <= color.b)) {
+            color.a = 0;
+        }
+        
+        // if color is NaN, set to minimum possible value
+        color.r = !(color.r <= 0.0 || 0.0 <= color.r) ? min($clim_r.x, $clim_r.y) : color.r;
+        color.g = !(color.g <= 0.0 || 0.0 <= color.g) ? min($clim_g.x, $clim_g.y) : color.g;
+        color.b = !(color.b <= 0.0 || 0.0 <= color.b) ? min($clim_b.x, $clim_b.y) : color.b;
+        // clamp data to minimum and maximum of clims
+        color.r = clamp(color.r, min($clim_r.x, $clim_r.y), max($clim_r.x, $clim_r.y));
+        color.g = clamp(color.g, min($clim_g.x, $clim_g.y), max($clim_g.x, $clim_g.y));
+        color.b = clamp(color.b, min($clim_b.x, $clim_b.y), max($clim_b.x, $clim_b.y));
+        // linearly scale data between clims
+        color.r = (color.r - $clim_r.x) / ($clim_r.y - $clim_r.x);
+        color.g = (color.g - $clim_g.x) / ($clim_g.y - $clim_g.x);
+        color.b = (color.b - $clim_b.x) / ($clim_b.y - $clim_b.x);
+        return max(color, 0);
+    }
+"""
 
-class CompositeLayerVisual(TiledGeolocatedImageVisual):
-    VERT_SHADER = None
-    FRAG_SHADER = None
-
-    def __init__(self, data_arrays, origin_x, origin_y, cell_width, cell_height,
-                 shape=None,
-                 tile_shape=(DEFAULT_TILE_HEIGHT, DEFAULT_TILE_WIDTH),
-                 texture_shape=(DEFAULT_TEXTURE_HEIGHT, DEFAULT_TEXTURE_WIDTH),
-                 wrap_lon=False,
-                 cmap='viridis', method='tiled', clim='auto', gamma=None,
-                 interpolation='nearest', **kwargs):
-        # projection properties to be filled in later
-        self.cell_width = None
-        self.cell_height = None
-        self.origin_x = None
-        self.origin_y = None
-        self.shape = None
-
-        if method != 'tiled':
-            raise ValueError("Only 'tiled' method is currently supported")
-        method = 'subdivide'
-        grid = (1, 1)
-
-        # visual nodes already have names, so be careful
-        if not hasattr(self, "name"):
-            self.name = kwargs.get("name", None)
-        self._viewable_mesh_mask = None
-        self._ref1 = None
-        self._ref2 = None
-
-        self.texture_shape = texture_shape
-        self.tile_shape = tile_shape
-        self.num_tex_tiles = self.texture_shape[0] * self.texture_shape[1]
-        self._stride = 0  # Current stride is None when we are showing the overview
-        self._latest_tile_box = None
-        self.wrap_lon = wrap_lon
-        self._tiles = {}
+_apply_gamma = """
+    vec4 apply_gamma(vec4 color) {
+        color.r = pow(color.r, $gamma_r);
+        color.g = pow(color.g, $gamma_g);
+        color.b = pow(color.b, $gamma_b);
+        return color;
+    }
+"""
 
-        # What tiles have we used and can we use (each texture uses the same 'state')
-        self.texture_state = TextureTileState(self.num_tex_tiles)
+_null_color_transform = 'vec4 pass(vec4 color) { return color; }'
 
-        self.set_channels(data_arrays, shape=shape,
-                          cell_width=cell_width, cell_height=cell_height,
-                          origin_x=origin_x, origin_y=origin_y)
-        self.ndim = len(self.shape) or [x for x in data_arrays if x is not None][0].ndim
-        self.num_channels = len(data_arrays)
 
-        # load 'float packed rgba8' interpolation kernel
-        # to load float interpolation kernel use
-        # `load_spatial_filters(packed=False)`
-        kernel, self._interpolation_names = load_spatial_filters()
-
-        self._kerneltex = Texture2D(kernel, interpolation='nearest')
-        # The unpacking can be debugged by changing "spatial-filters.frag"
-        # to have the "unpack" function just return the .r component. That
-        # combined with using the below as the _kerneltex allows debugging
-        # of the pipeline
-        # self._kerneltex = Texture2D(kernel, interpolation='linear',
-        #                             internalformat='r32f')
-
-        # create interpolation shader functions for available
-        # interpolations
-        fun = [Function(_interpolation_template % n)
-               for n in self._interpolation_names]
-        self._interpolation_names = [n.lower()
-                                     for n in self._interpolation_names]
-
-        self._interpolation_fun = dict(zip(self._interpolation_names, fun))
-        self._interpolation_names.sort()
-        self._interpolation_names = tuple(self._interpolation_names)
-
-        # overwrite "nearest" and "bilinear" spatial-filters
-        # with  "hardware" interpolation _data_lookup_fn
-        self._interpolation_fun['nearest'] = Function(_texture_lookup)
-        self._interpolation_fun['bilinear'] = Function(_texture_lookup)
-
-        if interpolation not in self._interpolation_names:
-            raise ValueError("interpolation must be one of %s" %
-                             ', '.join(self._interpolation_names))
-
-        self._interpolation = interpolation
-
-        # check texture interpolation
-        if self._interpolation == 'bilinear':
-            texture_interpolation = 'linear'
-        else:
-            texture_interpolation = 'nearest'
+class SIFTMultiChannelTiledGeolocatedMixin(SIFTTiledGeolocatedMixin):
+    def _normalize_data(self, data_arrays):
+        if not isinstance(data_arrays, (list, tuple)):
+            return super()._normalize_data(data_arrays)
 
-        self._method = method
-        self._grid = grid
-        self._need_texture_upload = True
-        self._need_vertex_update = True
-        self._need_colortransform_update = False
-        self._need_interpolation_update = True
-        self._textures = [TextureAtlas2D(self.texture_shape, tile_shape=self.tile_shape,
-                                         interpolation=texture_interpolation,
-                                         format="LUMINANCE", internalformat="R32F",
-                                         ) for i in range(self.num_channels)
-                          ]
-        self._subdiv_position = VertexBuffer()
-        self._subdiv_texcoord = VertexBuffer()
-
-        # impostor quad covers entire viewport
-        vertices = np.array([[-1, -1], [1, -1], [1, 1],
-                             [-1, -1], [1, 1], [-1, 1]],
-                            dtype=np.float32)
-        self._impostor_coords = VertexBuffer(vertices)
-        self._null_tr = NullTransform()
-
-        self._init_view(self)
-        if self.VERT_SHADER is None or self.FRAG_SHADER is None:
-            raise RuntimeError("No shader specified for this subclass")
-        super(ImageVisual, self).__init__(vcode=self.VERT_SHADER, fcode=self.FRAG_SHADER)
-        self.set_gl_state('translucent', cull_face=False)
-        self._draw_mode = 'triangles'
-
-        # define _data_lookup_fn as None, will be setup in
-        # self._build_interpolation()
-        self._data_lookup_fns = [Function(_rgb_texture_lookup) for i in range(self.num_channels)]
-
-        if isinstance(clim, str):
-            if clim != 'auto':
-                raise ValueError("C-limits can only be 'auto' or 2 floats for each provided channel")
-            clim = [clim] * self.num_channels
-        if not isinstance(cmap, (tuple, list)):
-            cmap = [cmap] * self.num_channels
-
-        assert (len(clim) == self.num_channels)
-        assert (len(cmap) == self.num_channels)
-        _clim = []
-        _cmap = []
-        for idx in range(self.num_channels):
-            cl = clim[idx]
-            if cl == 'auto':
-                _clim.append((np.nanmin(data_arrays[idx]), np.nanmax(data_arrays[idx])))
-            elif cl is None:
-                # Color limits don't matter (either empty channel array or other)
-                _clim.append((0., 1.))
-            elif isinstance(cl, tuple) and len(cl) == 2:
-                _clim.append(cl)
-            else:
-                raise ValueError("C-limits must be a 2-element tuple or the string 'auto' for each channel provided")
-
-            cm = cmap[idx]
-            _cmap.append(cm)
-        self.clim = _clim
-        self._texture_LUT = None
-        self.gamma = gamma if gamma is not None else (1.,) * self.num_channels
-        # only set colormap if it isn't None
-        # (useful when a subclass's shader doesn't expect a colormap)
-        if _cmap[0] is not None:
-            self.cmap = _cmap[0]
+        new_data = []
+        for data in data_arrays:
+            new_data.append(super()._normalize_data(data))
+        return new_data
 
-        self.overview_info = None
-        self.init_overview(data_arrays)
+    def _init_overview_data(self, ttile_idx, data_arrays):
+        new_arrays = []
+        for idx, data in enumerate(data_arrays):
+            if data is None:
+                new_arrays.append(None)
+                continue
+            _y_slice, _x_slice = self.calc.calc_overview_stride(image_shape=Point(data.shape[0], data.shape[1]))
+            overview_data = self._normalize_data(data[_y_slice, _x_slice])
+            new_arrays.append(overview_data)
+        self._texture.set_tile_data(ttile_idx, new_arrays)
+
+    def _init_geo_parameters(
+            self,
+            origin_x,
+            origin_y,
+            cell_width,
+            cell_height,
+            projection,
+            texture_shape,
+            tile_shape,
+            wrap_lon,
+            shape,
+            data_arrays
+    ):
+        if shape is None:
+            shape = self._compute_shape(shape, data_arrays)
+        ndim = len(shape) or [x for x in data_arrays if x is not None][0].ndim
+        data = ArrayProxy(ndim, shape)
+        super()._init_geo_parameters(
+            origin_x,
+            origin_y,
+            cell_width,
+            cell_height,
+            projection,
+            texture_shape,
+            tile_shape,
+            wrap_lon,
+            shape,
+            data,
+        )
 
-        self.freeze()
+        self.set_channels(
+            data_arrays, shape=shape, cell_width=cell_width,
+            cell_height=cell_height, origin_x=origin_x, origin_y=origin_y,
+        )
 
     def set_channels(self, data_arrays, shape=None,
                      cell_width=None, cell_height=None,
-                     origin_x=None, origin_y=None, **kwargs):
+                     origin_x=None, origin_y=None):
         assert (shape or data_arrays is not None), "`data` or `shape` must be provided"
         if cell_width is not None:
             self.cell_width = cell_width
@@ -877,7 +640,7 @@ class CompositeLayerVisual(TiledGeolocatedImageVisual):
             self.origin_x = origin_x
         if origin_y:
             self.origin_y = origin_y
-        self.shape = shape or max(data.shape for data in data_arrays if data is not None)
+        self.shape = self._compute_shape(shape, data_arrays)
         assert None not in (self.cell_width, self.cell_height, self.origin_x, self.origin_y, self.shape)
         # how many of the higher resolution channel tiles (smaller geographic area) make
         # up a low resolution channel tile
@@ -907,208 +670,264 @@ class CompositeLayerVisual(TiledGeolocatedImageVisual):
         # even though we might be looking at the exact same spot
         self._latest_tile_box = None
 
-    def init_overview(self, data_arrays):
-        """Create and add a low resolution version of the data that is always
-        shown behind the higher resolution image tiles.
-        """
-        # FUTURE: Actually use this data attribute. For now let the base
-        #         think there is data (not None)
-        self._data = ArrayProxy(self.ndim, self.shape)
-        self.overview_info = nfo = {}
-        y_slice, x_slice = self.calc.overview_stride
-        # Update kwargs to reflect the new spatial resolution of the overview image
-        nfo["cell_width"] = self.cell_width * x_slice.step
-        nfo["cell_height"] = self.cell_height * y_slice.step
-        # Tell the texture state that we are adding a tile that should never expire and should always exist
-        nfo["texture_tile_index"] = ttile_idx = self.texture_state.add_tile((0, 0, 0), expires=False)
-        for idx, data in enumerate(data_arrays):
+    @staticmethod
+    def _compute_shape(shape, data_arrays):
+        return shape or max(data.shape for data in data_arrays if data is not None)
+
+    def _get_stride(self, view_box):
+        s = self.calc.calc_stride(view_box, texture=self._lowest_rez)
+        return Point(np.int64(s[0] * self._lowest_factor), np.int64(s[1] * self._lowest_factor))
+
+    def _slice_texture_tile(self, data_arrays, y_slice, x_slice):
+        new_data = []
+        for data in data_arrays:
             if data is not None:
-                _y_slice, _x_slice = self.calc.calc_overview_stride(image_shape=Point(data.shape[0], data.shape[1]))
-                overview_data = data[_y_slice, _x_slice]
-            else:
-                overview_data = None
-            self._textures[idx].set_tile_data(ttile_idx, self._normalize_data(overview_data))
+                # explicitly ask for the parent class of MultiBandTextureAtlas2D
+                data = super()._slice_texture_tile(data, y_slice, x_slice)
+            new_data.append(data)
+        return new_data
+
+
+class MultiChannelImageVisual(ImageVisual):
+    """Visual subclass displaying an image from three separate arrays.
+
+    Note this Visual uses only GPU scaling, unlike the ImageVisual base
+    class which allows for CPU or GPU scaling.
+
+    Parameters
+    ----------
+    data : list
+        A 3-element list of numpy arrays with 2 dimensons where the
+        arrays are sorted by (R, G, B) order. These will be put together
+        to make an RGB image. The list can contain ``None`` meaning there
+        is no value for this channel currently, but it may be filled in
+        later. In this case the underlying GPU storage is still allocated,
+        but pre-filled with NaNs. Note that each channel may have different
+        shapes.
+    cmap : str | Colormap
+        Unused by this Visual, but is still provided to the ImageVisual base
+        class.
+    clim : str | tuple | list | None
+        Limits of each RGB data array. If provided as a string it must be
+        "auto" and the limits will be computed on the fly. If a 2-element
+        tuple then it will be considered the color limits for all channel
+        arrays. If provided as a 3-element list of 2-element tuples then
+        they represent the color limits of each channel array.
+    gamma : float | list
+        Gamma to use during colormap lookup.  Final value will be computed
+        ``val**gamma` for each RGB channel array. If provided as a float then
+        it will be used for each channel. If provided as a 3-element tuple
+        then each value is used for the separate channel arrays. Default is
+        1.0 for each channel.
+    **kwargs : dict
+        Keyword arguments to pass to :class:`~vispy.visuals.ImageVisual`. Note
+        that this Visual does not allow for ``texture_format`` to be specified
+        and is hardcoded to ``r32f`` internal texture format.
 
-        # Handle wrapping around the anti-meridian so there is a -180/180 continuous image
-        num_tiles = 1 if not self.wrap_lon else 2
-        tl = TESS_LEVEL * TESS_LEVEL
-        nfo["texture_coordinates"] = np.empty((6 * num_tiles * tl, 2), dtype=np.float32)
-        nfo["vertex_coordinates"] = np.empty((6 * num_tiles * tl, 2), dtype=np.float32)
-        factor_rez, offset_rez = self.calc.calc_tile_fraction(
-            0, 0, Point(np.int64(y_slice.step), np.int64(x_slice.step)))
-        nfo["texture_coordinates"][:6 * tl, :2] = self.calc.calc_texture_coordinates(ttile_idx, factor_rez, offset_rez,
-                                                                                     tessellation_level=TESS_LEVEL)
-        nfo["vertex_coordinates"][:6 * tl, :2] = self.calc.calc_vertex_coordinates(0, 0, y_slice.step, x_slice.step,
-                                                                                   factor_rez, offset_rez,
-                                                                                   tessellation_level=TESS_LEVEL)
-        self._set_vertex_tiles(nfo["vertex_coordinates"], nfo["texture_coordinates"])
+    """
 
-    @property
-    def gamma(self):
-        return self._gamma
+    VERTEX_SHADER = ImageVisual.VERTEX_SHADER
+    FRAGMENT_SHADER = ImageVisual.FRAGMENT_SHADER
+
+    def __init__(self, data_arrays, clim='auto', gamma=1.0, **kwargs):
+        if kwargs.get("texture_format") is not None:
+            raise ValueError("'texture_format' can't be specified with the "
+                             "'MultiChannelImageVisual'.")
+        kwargs["texture_format"] = "R32F"
+        if kwargs.get("cmap") is not None:
+            raise ValueError("'cmap' can't be specified with the"
+                             "'MultiChannelImageVisual'.")
+        kwargs["cmap"] = None
+        self.num_channels = len(data_arrays)
+        super().__init__(data_arrays, clim=clim, gamma=gamma, **kwargs)
 
-    @gamma.setter
-    def gamma(self, gamma):
-        assert isinstance(gamma, (tuple, list))
-        assert len(gamma) == self.num_channels
-        self._gamma = tuple(x if x is not None else 1. for x in gamma)
-        self._need_texture_upload = True
-        self.update()
+    def _init_texture(self, data_arrays, texture_format):
+        if self._interpolation == 'bilinear':
+            texture_interpolation = 'linear'
+        else:
+            texture_interpolation = 'nearest'
+
+        tex = MultiChannelGPUScaledTexture2D(
+            data_arrays,
+            internalformat=texture_format,
+            format="LUMINANCE",
+            interpolation=texture_interpolation,
+        )
+        return tex
+
+    def _get_shapes(self, data_arrays):
+        shapes = [x.shape for x in data_arrays if x is not None]
+        if not shapes:
+            raise ValueError("List of data arrays must contain at least one "
+                             "numpy array.")
+        return shapes
+
+    def _get_min_shape(self, data_arrays):
+        return min(self._get_shapes(data_arrays))
+
+    def _get_max_shape(self, data_arrays):
+        return max(self._get_shapes(data_arrays))
+
+    @property
+    def size(self):
+        """Get size of the image (width, height)."""
+        return self._get_max_shape(self._data)
 
     @property
     def clim(self):
-        return (self._clim if isinstance(self._clim, string_types) else
-                tuple(self._clim))
+        """Get color limits used when rendering the image (cmin, cmax)."""
+        return self._texture.clim
 
     @clim.setter
-    def clim(self, clim):
-        if isinstance(clim, string_types):
-            if clim != 'auto':
-                raise ValueError('clim must be "auto" if a string')
-        else:
-            # set clim to 0 and 1 for non-existent arrays
-            clim = [c if c is not None else (0., 1.) for c in clim]
-            clim = np.array(clim, float)
-            if clim.shape != (self.num_channels, 2) and clim.shape != (2,):
-                raise ValueError('clim must have either 2 elements or 6 (2 for each channel)')
-            elif clim.shape == (2,):
-                clim = np.array([clim, clim, clim], float)
-        self._clim = clim
-        self._need_texture_upload = True
+    def clim(self, clims):
+        if isinstance(clims, str) or len(clims) == 2:
+            clims = [clims] * self.num_channels
+        if self._texture.set_clim(clims):
+            self._need_texture_upload = True
+        self._update_colortransform_clim()
         self.update()
 
-    def _set_clim_vars(self):
-        for idx, lookup_fn in enumerate(self._data_lookup_fns):
-            lookup_fn["vmin"] = self._clim[idx, 0]
-            lookup_fn["vmax"] = self._clim[idx, 1]
-            lookup_fn["gamma"] = self._gamma[idx]
-
-    def _build_interpolation(self):
-        # assumes 'nearest' interpolation
-        for idx, lookup_fn in enumerate(self._data_lookup_fns):
-            self.shared_program.frag['get_data_%d' % (idx + 1,)] = lookup_fn
-            lookup_fn['texture'] = self._textures[idx]
-        self._need_interpolation_update = False
-
-    def _build_texture_tiles(self, data, stride, tile_box):
-        """Prepare and organize strided data in to individual tiles with associated information.
-        """
-        data = [self._normalize_data(d) for d in data]
+    def _update_colortransform_clim(self):
+        if self._need_colortransform_update:
+            # we are going to rebuild anyway so just do it later
+            return
+        try:
+            norm_clims = self._texture.clim_normalized
+        except RuntimeError:
+            return
+        else:
+            clim_names = ('clim_r', 'clim_g', 'clim_b')
+            # shortcut so we don't have to rebuild the whole color transform
+            for clim_name, clim in zip(clim_names, norm_clims):
+                # shortcut so we don't have to rebuild the whole color transform
+                self.shared_program.frag['color_transform'][1][clim_name] = clim
 
-        LOG.debug("Uploading texture data for %d tiles (%r)",
-                  (tile_box.bottom - tile_box.top) * (tile_box.right - tile_box.left), tile_box)
-        # Tiles start at upper-left so go from top to bottom
-        tiles_info = []
-        for tiy in range(tile_box.top, tile_box.bottom):
-            for tix in range(tile_box.left, tile_box.right):
-                already_in = (stride, tiy, tix) in self.texture_state
-                # Update the age if already in there
-                # Assume that texture_state does not change from the main thread if this is run in another
-                tex_tile_idx = self.texture_state.add_tile((stride, tiy, tix))
-                if already_in:
-                    # FIXME: we should make a list/set of the tiles we need to add before this
-                    continue
+    @property
+    def gamma(self):
+        """Get the gamma used when rendering the image."""
+        return self._gamma
 
-                # Assume we were given a total image worth of this stride
-                y_slice, x_slice = self.calc.calc_tile_slice(tiy, tix, tuple(stride))
-                textures_data = []
-                for chn_idx in range(self.num_channels):
-                    # force a copy of the data from the content array (provided by the workspace)
-                    # to a vispy-compatible contiguous float array
-                    # this can be a potentially time-expensive operation since content array is often huge and
-                    # always memory-mapped, so paging may occur
-                    # we don't want this paging deferred until we're back in the GUI thread pushing data to OpenGL!
-                    if data[chn_idx] is None:
-                        # we need to fill the texture with NaNs instead of actual data
-                        tile_data = None
-                    else:
-                        tile_data = np.array(data[chn_idx][y_slice, x_slice], dtype=np.float32)
-                    textures_data.append(tile_data)
-                tiles_info.append((stride, tiy, tix, tex_tile_idx, textures_data))
+    @gamma.setter
+    def gamma(self, value):
+        """Set gamma used when rendering the image."""
+        if not isinstance(value, (list, tuple)):
+            value = [value] * self.num_channels
+        if any(val <= 0 for val in value):
+            raise ValueError("gamma must be > 0")
+        self._gamma = tuple(float(x) for x in value)
+
+        gamma_names = ('gamma_r', 'gamma_g', 'gamma_b')
+        for gamma_name, gam in zip(gamma_names, self._gamma):
+            # shortcut so we don't have to rebuild the color transform
+            if not self._need_colortransform_update:
+                self.shared_program.frag['color_transform'][2][gamma_name] = gam
+        self.update()
 
-        return tiles_info
+    @ImageVisual.cmap.setter
+    def cmap(self, cmap):
+        if cmap is not None:
+            raise ValueError("MultiChannelImageVisual does not support a colormap.")
+        self._cmap = None
 
-    def _set_texture_tiles(self, tiles_info):
-        for tile_info in tiles_info:
-            stride, tiy, tix, tex_tile_idx, data_arrays = tile_info
-            for idx, data in enumerate(data_arrays):
-                self._textures[idx].set_tile_data(tex_tile_idx, data)
+    def _build_interpolation(self):
+        # assumes 'nearest' interpolation
+        interpolation = self._interpolation
+        if interpolation != 'nearest':
+            raise NotImplementedError("MultiChannelImageVisual only supports 'nearest' interpolation.")
+        texture_interpolation = 'nearest'
+
+        self._data_lookup_fn = Function(_rgb_texture_lookup)
+        self.shared_program.frag['get_data'] = self._data_lookup_fn
+        if self._texture.interpolation != texture_interpolation:
+            self._texture.interpolation = texture_interpolation
+        self._data_lookup_fn['texture_r'] = self._texture.textures[0]
+        self._data_lookup_fn['texture_g'] = self._texture.textures[1]
+        self._data_lookup_fn['texture_b'] = self._texture.textures[2]
 
-    def _get_stride(self, view_box):
-        s = self.calc.calc_stride(view_box, texture=self._lowest_rez)
-        return Point(np.int64(s[0] * self._lowest_factor), np.int64(s[1] * self._lowest_factor))
+        self._need_interpolation_update = False
 
+    def _build_color_transform(self):
+        if self.num_channels != 3:
+            raise NotImplementedError("MultiChannelimageVisuals only support 3 channels.")
+        else:
+            # RGB/A image data (no colormap)
+            fclim = Function(_apply_clim)
+            fgamma = Function(_apply_gamma)
+            fun = FunctionChain(None, [Function(_null_color_transform), fclim, fgamma])
+        fclim['clim_r'] = self._texture.textures[0].clim_normalized
+        fclim['clim_g'] = self._texture.textures[1].clim_normalized
+        fclim['clim_b'] = self._texture.textures[2].clim_normalized
+        fgamma['gamma_r'] = self.gamma[0]
+        fgamma['gamma_g'] = self.gamma[1]
+        fgamma['gamma_b'] = self.gamma[2]
+        return fun
+
+    def set_data(self, data_arrays):
+        """Set the data
 
-CompositeLayer = create_visual_node(CompositeLayerVisual)
+        Parameters
+        ----------
+        image : array-like
+            The image data.
+        """
+        if self._data is not None and any(self._shape_differs(x1, x2) for x1, x2 in zip(self._data, data_arrays)):
+            self._need_vertex_update = True
+        data_arrays = list(self._cast_arrays_if_needed(data_arrays))
+        self._texture.check_data_format(data_arrays)
+        self._data = data_arrays
+        self._need_texture_upload = True
 
-RGB_VERT_SHADER = """
-uniform int method;  // 0=subdivide, 1=impostor
-attribute vec2 a_position;
-attribute vec2 a_texcoord;
-varying vec2 v_texcoord;
+    @staticmethod
+    def _cast_arrays_if_needed(data_arrays):
+        for data in data_arrays:
+            if data is not None and should_cast_to_f32(data.dtype):
+                data = data.astype(np.float32)
+            yield data
+
+    @staticmethod
+    def _shape_differs(arr1, arr2):
+        none_change1 = arr1 is not None and arr2 is None
+        none_change2 = arr1 is None and arr2 is not None
+        shape_change = False
+        if arr1 is not None and arr2 is not None:
+            shape_change = arr1.shape[:2] != arr2.shape[:2]
+        return none_change1 or none_change2 or shape_change
 
-void main() {
-    v_texcoord = a_texcoord;
-    gl_Position = $transform(vec4(a_position, 0., 1.));
-}
-"""
+    def _build_texture(self):
+        pre_clims = self._texture.clim
+        pre_internalformat = self._texture.internalformat
+        self._texture.scale_and_set_data(self._data)
+        post_clims = self._texture.clim
+        post_internalformat = self._texture.internalformat
+        # color transform needs rebuilding if the internalformat was changed
+        # new color limits need to be assigned if the normalized clims changed
+        # otherwise, the original color transform should be fine
+        # Note that this assumes that if clim changed, clim_normalized changed
+        new_if = post_internalformat != pre_internalformat
+        new_cl = post_clims != pre_clims
+        if new_if or new_cl:
+            self._need_colortransform_update = True
+        self._need_texture_upload = False
 
-RGB_FRAG_SHADER = """
-uniform vec2 image_size;
-uniform int method;  // 0=subdivide, 1=impostor
-uniform sampler2D u_texture;
-varying vec2 v_texcoord;
-
-vec4 map_local_to_tex(vec4 x) {
-    // Cast ray from 3D viewport to surface of image
-    // (if $transform does not affect z values, then this
-    // can be optimized as simply $transform.map(x) )
-    vec4 p1 = $transform(x);
-    vec4 p2 = $transform(x + vec4(0, 0, 0.5, 0));
-    p1 /= p1.w;
-    p2 /= p2.w;
-    vec4 d = p2 - p1;
-    float f = p2.z / d.z;
-    vec4 p3 = p2 - d * f;
-
-    // finally map local to texture coords
-    return vec4(p3.xy / image_size, 0, 1);
-}
-
-
-void main()
-{
-    vec2 texcoord;
-    if( method == 0 ) {
-        texcoord = v_texcoord;
-    }
-    else {
-        // vertex shader ouptuts clip coordinates;
-        // fragment shader maps to texture coordinates
-        texcoord = map_local_to_tex(vec4(v_texcoord, 0, 1)).xy;
-    }
 
-    vec4 r_tmp, g_tmp, b_tmp;
-    r_tmp = $get_data_1(texcoord);
-    g_tmp = $get_data_2(texcoord);
-    b_tmp = $get_data_3(texcoord);
+MultiChannelImage = create_visual_node(MultiChannelImageVisual)
 
-    // Make the pixel transparent if all of the values are NaN/fill values
-    if (r_tmp.a == 0 && g_tmp.a == 0 && b_tmp.a == 0) {
-        gl_FragColor.a = 0;
-    } else {
-        gl_FragColor.a = 1;
-    }
-    gl_FragColor.r = r_tmp.r;
-    gl_FragColor.g = g_tmp.r;
-    gl_FragColor.b = b_tmp.r;
-}
-"""  # noqa
 
+class RGBCompositeLayerVisual(SIFTMultiChannelTiledGeolocatedMixin,
+                              TiledGeolocatedImageVisual,
+                              MultiChannelImageVisual):
+    def _init_texture(self, data_arrays, texture_format):
+        if self._interpolation == 'bilinear':
+            texture_interpolation = 'linear'
+        else:
+            texture_interpolation = 'nearest'
 
-class RGBCompositeLayerVisual(CompositeLayerVisual):
-    VERT_SHADER = RGB_VERT_SHADER
-    FRAG_SHADER = RGB_FRAG_SHADER
+        tex_shapes = [self.texture_shape] * len(data_arrays)
+        tex = MultiChannelTextureAtlas2D(
+            tex_shapes, tile_shape=self.tile_shape,
+            interpolation=texture_interpolation, format="LUMINANCE", internalformat="R32F"
+        )
+        return tex
 
 
 RGBCompositeLayer = create_visual_node(RGBCompositeLayerVisual)