40 lines
1.1 KiB
Python
40 lines
1.1 KiB
Python
import fastai.vision.all as fai
|
|
import rasterio as rio
|
|
import warnings
|
|
|
|
# Ignore warnings whose message starts with "Dataset has no geotransform"
# (case-insensitive regex match) when they are attributed to a module whose
# name matches 'rasterio'.
# NOTE(review): `module` is matched against the module the warning is
# attributed to (via stacklevel); confirm rasterio attributes it to itself.
warnings.filterwarnings(
    action='ignore',
    message='Dataset has no geotransform',
    module='rasterio',
)
|
|
|
|
|
|
class RasterBase(fai.TensorImage):
    """A fastai `TensorImage` whose pixels are loaded from a raster file with rasterio."""

    @classmethod
    def create(cls, fname, **kwargs):
        """Read every band of the raster at `fname` and wrap it in `cls`.

        Args:
            fname: Anything `rasterio.open` accepts (typically a file path).
            **kwargs: Accepted to match fastai's `create(fn, **kwargs)`
                convention; currently ignored.

        Returns:
            An instance of `cls` built from the array returned by
            `DatasetReader.read()` with no arguments (all bands; per the
            rasterio docs the array is laid out as (bands, rows, cols)).
        """
        # Use the dataset as a context manager so the underlying file handle
        # is released as soon as the pixels are read (previously leaked).
        with rio.open(fname) as src:
            data = src.read()
        return cls(data)
|
|
|
|
|
|
class RasterVNIR(RasterBase):
    """Raster image whose `show` renders three selected bands as a contrast-stretched RGB preview."""

    def show(self, ctx=None, bands=(3, 2, 1), **kwargs):
        """Display `bands` of this raster via `fai.show_image`.

        The selected bands are cast to float, divided by their joint 0.9
        quantile and clipped to [0, 1] (a simple contrast stretch).

        Args:
            ctx: Axes to draw into; forwarded to `show_image`.
            bands: Indices along the first axis to display; with three
                indices they become the red, green and blue channels in that
                order. Defaults to ``(3, 2, 1)``.
            **kwargs: Extra keyword arguments forwarded to `show_image`.
                ``vmax`` defaults to 300 and can now be overridden.

        Returns:
            Whatever `fai.show_image` returns.
        """
        # Tuples must become lists: tensor[(3, 2, 1)] is a multi-dimensional
        # index, whereas tensor[[3, 2, 1]] selects rows along the first axis
        # (the original list default). Other argument types pass through.
        idx = list(bands) if isinstance(bands, tuple) else bands
        im = self[idx].to(fai.torch.float)
        scale = im.quantile(.9)
        # An all-zero / mostly-empty selection has a zero quantile; dividing by
        # it would turn the whole preview into inf/NaN.
        if scale > 0:
            im = im / scale
        im = im.clip(0, 1)
        # NOTE(review): after clipping to [0, 1], vmax=300 only affects
        # colour-mapped (single-band) display, where it renders nearly black.
        # Kept as the default for backward compatibility — confirm intent.
        kwargs.setdefault('vmax', 300)
        return fai.show_image(im, ctx=ctx, **kwargs)
|
|
|
|
|
|
class RasterToFloatTensor(fai.DisplayedTransform):
    """Convert a `TensorImage` to float by dividing by `div`, and invert that on decode."""

    # fastai pipelines apply transforms in ascending `order`.
    order = 10

    def __init__(self, div=None):
        """Store the divisor.

        Args:
            div: Value the float tensor is divided by. ``None`` (default)
                means the maximum of uint16, i.e. 65535.
        """
        if div is None:
            # The module-level imports do not bring `np` into scope, so the
            # default path previously raised NameError.
            import numpy as np
            div = np.iinfo(np.uint16).max
        self.div = div

    def encodes(self, o: fai.TensorImage):
        """Return `o` as a float tensor divided by `self.div`."""
        # Out-of-place division: `.float()` returns `o` itself when it is
        # already float32, so an in-place `div_` would mutate the input.
        return o.float().div(self.div)

    def decodes(self, o: fai.TensorImage):
        """Undo `encodes`: clamp to [0, 1], multiply by `self.div`, cast to int64."""
        return (o.clamp(0., 1.) * self.div).long()
|
|
|
|
|
|
def RasterBlock(cls=RasterBase, batch_tfms=None):
    """Return a fastai `TransformBlock` that loads rasters with `cls.create`.

    Args:
        cls: Class whose `create` classmethod builds an item from a file
            name. Defaults to `RasterBase`.
        batch_tfms: Optional batch transforms forwarded to `TransformBlock`
            (for example ``RasterToFloatTensor``). ``None`` keeps the previous
            behaviour of adding no batch transforms.
    """
    return fai.TransformBlock(type_tfms=cls.create, batch_tfms=batch_tfms)