geo-fastai/geofastai/raster.py

40 lines
1.1 KiB
Python

import warnings

import fastai.vision.all as fai
import numpy as np
import rasterio as rio
# Silence rasterio's "Dataset has no geotransform" warning module-wide.
# NOTE(review): presumably raised when reading files without georeferencing
# (e.g. plain image chips) — confirm that is the only case worth hiding.
warnings.filterwarnings(
    action='ignore',
    message='Dataset has no geotransform',
    module='rasterio',
)
class RasterBase(fai.TensorImage):
    """fastai ``TensorImage`` whose pixels are loaded from a raster file."""

    @classmethod
    def create(cls, fname, **kwargs):
        """Read every band of the raster at ``fname`` into a new instance.

        Args:
            fname: Path (or anything ``rasterio.open`` accepts) of the raster.
            **kwargs: Accepted for compatibility with fastai ``create``
                callers; currently ignored.

        Returns:
            An instance of ``cls`` wrapping the array returned by rasterio's
            ``DatasetReader.read()`` (all bands).
        """
        # Open the dataset as a context manager so its file handle is closed
        # once the pixels are in memory instead of lingering per item.
        with rio.open(fname) as src:
            data = src.read()
        return cls(data)
class RasterVNIR(RasterBase):
    """Raster tensor that displays a three-band composite.

    NOTE(review): the name suggests visible/near-infrared imagery and the
    default bands (3, 2, 1) suggest an RGB ordering — confirm against the
    sensor's band layout.
    """

    def show(self, ctx=None, bands=None, **kwargs):
        """Display the selected bands, contrast-stretched to ``[0, 1]``.

        Args:
            ctx: Matplotlib axes to draw into (forwarded to ``show_image``).
            bands: Indices selected along the first axis; defaults to
                ``[3, 2, 1]``. Pass a list: a tuple would index several
                dimensions instead of selecting bands.
            **kwargs: Forwarded to ``fai.show_image``. ``vmin``/``vmax``
                default to 0/1 to match the normalised range, but may be
                overridden.

        Returns:
            Whatever ``fai.show_image`` returns.
        """
        if bands is None:
            bands = [3, 2, 1]  # avoid a shared mutable default argument
        im = self[bands].to(fai.torch.float)
        # Stretch by the 90th percentile of all selected pixels so a few bright
        # outliers do not flatten the image.
        # NOTE(review): torch.quantile rejects very large inputs; full scenes
        # may need subsampling — confirm for the expected image sizes.
        scale = im.quantile(.9)
        if scale != 0:
            # A zero percentile would turn zero pixels into NaN (0/0).
            im = im / scale
        im = im.clip(0, 1)
        # Data is now in [0, 1]; the former hard-coded vmax=300 rendered
        # single-band images almost black and clashed with a caller's vmax.
        kwargs.setdefault('vmin', 0)
        kwargs.setdefault('vmax', 1)
        return fai.show_image(im, ctx=ctx, **kwargs)
class RasterToFloatTensor(fai.DisplayedTransform):
    """Scale integer raster tensors to floats by ``div``, and back on decode.

    The default divisor is the ``uint16`` maximum (65535).
    """

    # NOTE(review): presumably chosen to run alongside fastai's built-in
    # int-to-float transform ordering — confirm.
    order = 10

    def __init__(self, div=None):
        """Store the divisor.

        Args:
            div: Value to divide pixels by; ``None`` means
                ``np.iinfo(np.uint16).max`` (65535).
        """
        if div is None:
            div = np.iinfo(np.uint16).max
        self.div = div

    # The ``fai.TensorImage`` annotations drive fastai's type dispatch; keep
    # them, and do not add return annotations (fastai would cast results).
    def encodes(self, o:fai.TensorImage):
        # Out-of-place division: ``.float()`` returns ``o`` itself when it is
        # already float32, so an in-place ``div_`` would mutate the input.
        return o.float() / self.div

    def decodes(self, o:fai.TensorImage):
        # Round before the integer cast so float error in ``x / div * div``
        # cannot truncate a value to one below the original.
        return (o.clamp(0., 1.) * self.div).round().long()
def RasterBlock(cls=RasterBase, batch_tfms=None):
    """Build a fastai ``TransformBlock`` that loads rasters via ``cls.create``.

    Args:
        cls: Raster class whose ``create`` classmethod turns a filename into
            a tensor.
        batch_tfms: Optional batch transforms forwarded to ``TransformBlock``
            (e.g. ``RasterToFloatTensor``). None are applied by default.

    Returns:
        A ``fai.TransformBlock``.
    """
    return fai.TransformBlock(type_tfms=cls.create, batch_tfms=batch_tfms)