Coverage for gws-app/gws/gis/gdalx/__init__.py: 0%
333 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-17 01:37 +0200
1"""GDAL/OGR wrapper."""
3from typing import Optional, cast
5import contextlib
7from osgeo import gdal
8from osgeo import ogr
9from osgeo import osr
11import gws
12import gws.base.shape
13import gws.gis.crs
14import gws.lib.datetimex as datetimex
class Error(gws.Error):
    """Error raised by the GDAL/OGR wrapper functions in this module."""
class DriverInfo(gws.Data):
    """Description of a GDAL driver, as returned by ``drivers()``."""

    index: int  # position in GDAL's driver list, i.e. ``gdal.GetDriver(index)``
    name: str  # driver short name (``drv.ShortName``)
    longName: str  # human-readable driver name (``drv.LongName``)
    metaData: dict  # driver metadata (``drv.GetMetadata()``), e.g. extensions and DCAP_* capability flags
class _DriverInfoCache(gws.Data):
    """Index of the available GDAL drivers, built once by ``_fetch_driver_infos``."""

    infos: list[DriverInfo]  # one entry per driver, in GDAL driver order
    extToName: dict[str, list[str]]  # file extension -> short names of the drivers declaring it
    vectorNames: set[str]  # short names of drivers with DCAP_VECTOR=YES
    rasterNames: set[str]  # short names of drivers with DCAP_RASTER=YES
def drivers() -> list[DriverInfo]:
    """Return descriptions of all available GDAL drivers.

    NB the returned list is the cached index itself; do not modify it.
    """
    return _fetch_driver_infos().infos
def open_raster(
        path: str,
        mode: str = 'r',
        driver: str = '',
        default_crs: Optional[gws.Crs] = None,
        **opts
) -> 'RasterDataSet':
    """Open (or create) a raster dataset at a path.

    Args:
        path: File path.
        mode: 'r' (=read), 'a' (=update), 'w' (=create/write).
        driver: GDAL driver name; if empty, it is guessed from the path extension.
        default_crs: Default CRS (falls back to Webmercator).
        **opts: Extra keyword options for gdal.OpenEx/CreateDataSource.

    Returns:
        A RasterDataSet.
    """
    return _open(
        path,
        mode,
        driver,
        encoding=None,
        default_crs=default_crs,
        opts=opts,
        need_raster=True,
    )
def open_vector(
        path: str,
        mode: str = 'r',
        driver: str = '',
        encoding: str = 'utf8',
        default_crs: Optional[gws.Crs] = None,
        **opts
) -> 'VectorDataSet':
    """Open (or create) a vector dataset at a path.

    Args:
        path: File path.
        mode: 'r' (=read), 'a' (=update), 'w' (=create/write).
        driver: OGR driver name; if empty, it is guessed from the path extension.
        encoding: If set, string field values are decoded with this encoding;
            if empty/None, they are returned as raw bytes.
        default_crs: Default CRS for geometries (falls back to Webmercator).
        **opts: Extra keyword options for gdal.OpenEx/CreateDataSource.

    Returns:
        A VectorDataSet.
    """
    return _open(
        path,
        mode,
        driver,
        encoding=encoding,
        default_crs=default_crs,
        opts=opts,
        need_raster=False,
    )
def _open(path, mode, driver, encoding, default_crs, opts, need_raster):
    """Open or create a dataset and wrap it in a RasterDataSet/VectorDataSet.

    Args:
        path: File path.
        mode: 'r' (read), 'a' (update) or 'w' (create); empty means 'r'.
        driver: Driver name; if empty, it is guessed from the path extension.
        encoding: String encoding stored on the wrapper (None for rasters).
        default_crs: Default CRS; falls back to Webmercator.
        opts: Extra keyword options for CreateDataSource / gdal.OpenEx.
        need_raster: True to open a raster dataset, False for a vector one.

    Raises:
        Error: on an invalid mode, a missing file, or if the dataset cannot be created/opened.
    """
    mode = mode or 'r'
    # NB compare whole tokens: the substring test `mode in 'rwa'` also accepted
    # e.g. 'rw' or 'wa', which then opened with neither a read-only nor an update flag.
    if mode not in ('r', 'w', 'a'):
        raise Error(f'invalid open mode {mode!r}')

    gdal.UseExceptions()

    drv = _driver_from_args(path, driver, need_raster)
    default_crs = default_crs or gws.gis.crs.WEBMERCATOR
    wrapper_cls = RasterDataSet if need_raster else VectorDataSet

    if mode == 'w':
        # NOTE(review): for rasters `drv` is a gdal.Driver; confirm it supports
        # CreateDataSource (raster creation normally goes through Create/CreateCopy).
        gd = drv.CreateDataSource(path, **opts)
        if gd is None:
            raise Error(f'cannot create {path!r}')
        return wrapper_cls(path, gd, encoding, default_crs)

    if not gws.u.is_file(path):
        raise Error(f'file not found {path!r}')

    flags = gdal.OF_VERBOSE_ERROR
    flags |= gdal.OF_READONLY if mode == 'r' else gdal.OF_UPDATE
    flags |= gdal.OF_RASTER if need_raster else gdal.OF_VECTOR

    gd = gdal.OpenEx(path, flags, **opts)
    if gd is None:
        raise Error(f'cannot open {path!r}')

    return wrapper_cls(path, gd, encoding, default_crs)
def open_from_image(image: gws.Image, bounds: gws.Bounds) -> 'RasterDataSet':
    """Wrap an Image in an in-memory raster dataset.

    Every channel of the image array becomes one Byte band. The dataset is
    georeferenced north-up so that the image covers ``bounds.extent`` in ``bounds.crs``.

    Args:
        image: Image object.
        bounds: Geographic bounds of the image.

    Returns:
        A RasterDataSet backed by the GDAL 'MEM' driver.
    """
    gdal.UseExceptions()

    mem_driver = gdal.GetDriverByName('MEM')
    pixels = image.to_array()
    # the array is indexed as [row, column, channel]
    rows, cols, channels = pixels.shape[0], pixels.shape[1], pixels.shape[2]

    gd = mem_driver.Create('', cols, rows, channels, gdal.GDT_Byte)
    for band_no in range(1, channels + 1):
        gd.GetRasterBand(band_no).WriteArray(pixels[:, :, band_no - 1])

    extent = bounds.extent
    x_min, y_min, x_max, y_max = extent[0], extent[1], extent[2], extent[3]

    pixel_w = (x_max - x_min) / gd.RasterXSize
    # negative: row 0 is the top (y_max) edge
    pixel_h = (y_min - y_max) / gd.RasterYSize

    gd.SetGeoTransform((x_min, pixel_w, 0, y_max, 0, pixel_h))
    gd.SetSpatialRef(_srs_from_srid(bounds.crs.srid))

    return RasterDataSet('', gd)
167##
class _DataSet:
    """Base wrapper around an opened GDAL/OGR dataset.

    Can be used as a context manager: the dataset is closed on exit.
    """

    gdDataset: gdal.Dataset
    gdDriver: gdal.Driver
    path: str
    driverName: str
    defaultCrs: gws.Crs
    encoding: str

    def __init__(self, path, gd_dataset, encoding='utf8', default_crs: Optional[gws.Crs] = None):
        """Wrap ``gd_dataset`` opened from ``path``.

        Args:
            path: Path the dataset was opened from ('' for in-memory datasets).
            gd_dataset: The underlying GDAL/OGR dataset object.
            encoding: String encoding for attribute values (None/'' means raw bytes).
            default_crs: Fallback CRS, Webmercator if omitted.
        """
        self.path = path
        self.gdDataset = gd_dataset
        self.gdDriver = self.gdDataset.GetDriver()
        self.driverName = self.gdDriver.GetDescription()
        self.encoding = encoding
        self.defaultCrs = default_crs or gws.gis.crs.WEBMERCATOR

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()
        # do not suppress exceptions raised in the `with` body
        return False

    def close(self):
        """Flush pending changes and release the dataset.

        Safe to call more than once, e.g. an explicit ``close()`` inside a
        ``with`` block followed by ``__exit__``.
        """
        if getattr(self, 'gdDataset', None) is None:
            return
        self.gdDataset.FlushCache()
        # drop our reference; the GDAL bindings close a dataset once it is no longer referenced
        # (setattr keeps type checkers quiet about assigning None to a non-Optional attribute)
        setattr(self, 'gdDataset', None)
class RasterDataSet(_DataSet):
    """Raster dataset wrapper."""

    def create_copy(self, path: str, driver: str = '', strict=False, **opts):
        """Write a copy of this raster dataset to a path.

        The dataset metadata is copied too; the new dataset is flushed and
        released before returning.

        Args:
            path: Target file path.
            driver: Target driver name; guessed from the ``path`` extension if empty.
            strict: Passed to ``CreateCopy`` as its strict flag (1/0).
            **opts: Extra keyword options for ``CreateCopy``.
        """
        gdal.UseExceptions()

        target_drv = _driver_from_args(path, driver, need_raster=True)
        copy_gd = target_drv.CreateCopy(path, self.gdDataset, int(bool(strict)), **opts)
        copy_gd.SetMetadata(self.gdDataset.GetMetadata())
        copy_gd.FlushCache()
        # release our reference to the new dataset
        del copy_gd
class VectorDataSet(_DataSet):
    """Vector (OGR) dataset wrapper."""

    @contextlib.contextmanager
    def transaction(self):
        """Run the enclosed ``with`` block inside a dataset transaction.

        The transaction is committed when the block completes. If the block (or
        the commit itself) raises, a rollback is attempted and the exception is
        re-raised.
        """
        self.gdDataset.StartTransaction()
        try:
            yield self
            self.gdDataset.CommitTransaction()
        except BaseException:
            # NB deliberately broad: also roll back on KeyboardInterrupt etc.
            self.gdDataset.RollbackTransaction()
            raise

    def create_layer(
            self,
            name: str,
            columns: dict[str, gws.AttributeType],
            geometry_type: Optional[gws.GeometryType] = None,
            crs: Optional[gws.Crs] = None,
            overwrite=False,
            *options,
    ) -> 'VectorLayer':
        """Create a new layer with the given attribute columns.

        Args:
            name: Layer name.
            columns: Column name -> attribute type. Every type must be
                supported (see ``is_attribute_supported``).
            geometry_type: Geometry type of the layer. An unsupported type logs a
                warning and falls back to ``ogr.wkbUnknown``. If omitted, the layer
                is created as ``wkbUnknown`` without an SRS.
            crs: CRS of the geometry (defaults to the dataset default CRS);
                only used together with ``geometry_type``.
            overwrite: Add the 'OVERWRITE=YES' layer creation option.
            *options: Additional layer creation options (e.g. 'KEY=VALUE' strings).

        Returns:
            The new VectorLayer.

        Raises:
            Error: if a column has an unsupported attribute type.
        """
        opts = list(options)
        if overwrite:
            opts.append('OVERWRITE=YES')

        geom_type = ogr.wkbUnknown
        srs = None

        if geometry_type:
            geom_type = _GEOM_TO_OGR.get(geometry_type)
            if not geom_type:
                gws.log.warning(f'gdal: unsupported {geometry_type=}')
                geom_type = ogr.wkbUnknown
            crs = crs or self.defaultCrs
            srs = _srs_from_srid(crs.srid)

        # validate all column types before anything is created, so that an
        # unsupported type does not leave a half-built layer behind
        field_types = {}
        for col_name, col_type in columns.items():
            ogr_type = _ATTR_TO_OGR.get(col_type)
            if ogr_type is None:
                raise Error(f'gdal: unsupported attribute type {col_type!r} for column {col_name!r}')
            field_types[col_name] = ogr_type

        gd_layer = self.gdDataset.CreateLayer(
            name,
            geom_type=geom_type,
            srs=srs,
            options=opts,
        )
        for col_name, ogr_type in field_types.items():
            gd_layer.CreateField(ogr.FieldDefn(col_name, ogr_type))

        return VectorLayer(self, gd_layer)

    def layers(self) -> list['VectorLayer']:
        """Return all layers of the dataset, in index order."""
        cnt = self.gdDataset.GetLayerCount()
        return [VectorLayer(self, self.gdDataset.GetLayerByIndex(n)) for n in range(cnt)]

    def layer(self, name_or_index: str | int) -> Optional['VectorLayer']:
        """Return the layer with the given name or index, or None if there is none."""
        gd_layer = self.gdDataset.GetLayer(name_or_index)
        if gd_layer is None:
            return None
        return VectorLayer(self, gd_layer)
class VectorLayer:
    """Wrapper around an OGR layer belonging to a VectorDataSet."""

    name: str
    gdLayer: ogr.Layer
    gdDefn: ogr.FeatureDefn
    encoding: str
    defaultCrs: gws.Crs

    def __init__(self, ds: VectorDataSet, gd_layer: ogr.Layer):
        self.gdLayer = gd_layer
        self.gdDefn = self.gdLayer.GetLayerDefn()
        self.name = self.gdDefn.GetName()
        # encoding and fallback CRS are inherited from the dataset
        self.encoding = ds.encoding
        self.defaultCrs = ds.defaultCrs

    def describe(self) -> gws.DataSetDescription:
        """Describe the layer's columns.

        The FID column, if the layer has one, comes first as an integer primary
        key. It is followed by the attribute fields and then the geometry fields.
        Fields whose OGR type has no mapping are skipped. The last geometry
        column found is reported as the layer geometry.
        """
        desc = gws.DataSetDescription(
            columns=[],
            columnMap={},
            fullName=self.name,
            geometryName='',
            geometrySrid=0,
            geometryType='',
            name=self.name,
            schema='',
        )

        cols = []

        fid_col = self.gdLayer.GetFIDColumn()
        if fid_col:
            cols.append(gws.ColumnDescription(
                name=fid_col,
                type=_OGR_TO_ATTR[ogr.OFTInteger],
                nativeType=ogr.OFTInteger,
                isPrimaryKey=True,
                columnIndex=0,
            ))

        for i in range(self.gdDefn.GetFieldCount()):
            fdef: ogr.FieldDefn = self.gdDefn.GetFieldDefn(i)
            typ = fdef.GetType()
            if typ not in _OGR_TO_ATTR:
                continue
            cols.append(gws.ColumnDescription(
                name=fdef.GetName(),
                type=_OGR_TO_ATTR[typ],
                nativeType=typ,
                columnIndex=i,
            ))

        for i in range(self.gdDefn.GetGeomFieldCount()):
            gdef: ogr.GeomFieldDefn = self.gdDefn.GetGeomFieldDefn(i)
            typ = gdef.GetType()
            if typ not in _OGR_TO_GEOM:
                continue
            cols.append(gws.ColumnDescription(
                name=gdef.GetName(),
                type=gws.AttributeType.geometry,
                nativeType=typ,
                columnIndex=i,
                geometryType=_OGR_TO_GEOM[typ],
                # NB the geometry field may have no SRS, or one without a numeric code
                geometrySrid=self._srid_from_srs(gdef.GetSpatialRef()),
            ))

        desc.columns = cols
        desc.columnMap = {c.name: c for c in cols}

        for c in cols:
            # NB take the last geom
            if c.geometryType:
                desc.geometryName = c.name
                desc.geometryType = c.geometryType
                desc.geometrySrid = c.geometrySrid

        return desc

    def insert(self, records: list[gws.FeatureRecord]) -> list[int]:
        """Insert records as new features.

        Shapes are written only if the layer has a geometry column. Integer
        uids are used as feature ids. Attributes that are None, or that have
        no matching layer field, are skipped.

        Returns:
            The feature ids of the inserted features, in record order.

        Raises:
            Error: if an attribute value cannot be stored in its field.
        """
        desc = self.describe()
        fids = []

        for rec in records:
            gd_feature = ogr.Feature(self.gdDefn)
            if desc.geometryType and rec.shape:
                gd_feature.SetGeometry(
                    ogr.CreateGeometryFromWkt(
                        rec.shape.to_wkt(),
                        _srs_from_srid(rec.shape.crs.srid)
                    ))

            if rec.uid and isinstance(rec.uid, int):
                gd_feature.SetFID(rec.uid)

            for col in desc.columns:
                if col.geometryType or col.isPrimaryKey:
                    continue
                val = rec.attributes.get(col.name)
                if val is None:
                    continue
                try:
                    _attr_to_ogr(gd_feature, int(col.nativeType), col.columnIndex, val, self.encoding)
                except Exception as exc:
                    raise Error(f'field cannot be set: {col.name=} {val=}') from exc

            self.gdLayer.CreateFeature(gd_feature)
            fids.append(gd_feature.GetFID())

        return fids

    def count(self, force=False):
        """Return the feature count reported by OGR; ``force`` is passed to GetFeatureCount as 1/0."""
        return self.gdLayer.GetFeatureCount(force=1 if force else 0)

    def get_all(self) -> list[gws.FeatureRecord]:
        """Read all features of the layer, restarting from the first one."""
        records = []

        self.gdLayer.ResetReading()

        while True:
            gd_feature = self.gdLayer.GetNextFeature()
            if gd_feature is None:
                break
            records.append(self._feature_record(gd_feature))

        return records

    def get(self, fid: int) -> Optional[gws.FeatureRecord]:
        """Return the feature with the given id, or None if the layer returns none."""
        gd_feature = self.gdLayer.GetFeature(fid)
        if gd_feature is None:
            return None
        return self._feature_record(gd_feature)

    def _feature_record(self, gd_feature):
        """Convert an OGR feature into a FeatureRecord (uid is the FID as a string)."""
        rec = gws.FeatureRecord(
            attributes={},
            shape=None,
            meta={'layerName': self.name},
            uid=str(gd_feature.GetFID()),
        )

        for i in range(gd_feature.GetFieldCount()):
            gd_field_defn: ogr.FieldDefn = gd_feature.GetFieldDefnRef(i)
            name = gd_field_defn.GetName()
            rec.attributes[name] = _attr_from_ogr(gd_feature, gd_field_defn.GetType(), i, self.encoding)

        cnt = gd_feature.GetGeomFieldCount()
        if cnt > 0:
            # NB take the last geom
            # @TODO multigeometry support
            gd_geom = gd_feature.GetGeomFieldRef(cnt - 1)
            if gd_geom is not None:
                crs = self._crs_from_srs(gd_geom.GetSpatialReference())
                rec.shape = gws.base.shape.from_wkt(gd_geom.ExportToWkt(), crs)

        return rec

    def _srid_from_srs(self, srs: Optional[osr.SpatialReference]) -> int:
        """Return the numeric authority code of ``srs``, or the default CRS srid if there is none."""
        code = srs.GetAuthorityCode(None) if srs is not None else None
        if code and code.isdigit():
            return int(code)
        return self.defaultCrs.srid

    def _crs_from_srs(self, srs: Optional[osr.SpatialReference]) -> gws.Crs:
        """Return the Crs for ``srs``, falling back to the default CRS if it cannot be resolved."""
        if srs is None:
            return self.defaultCrs
        srid = srs.GetAuthorityCode(None)
        crs = gws.gis.crs.get(srid) if srid else None
        return crs or self.defaultCrs
427##
def _driver_from_args(path, driver_name, need_raster):
    """Resolve the GDAL (raster) or OGR (vector) driver to use for a path.

    Args:
        path: File path; its extension selects the driver when ``driver_name`` is empty.
        driver_name: Explicit driver short name, or empty.
        need_raster: True if a raster-capable driver is required, False for a vector one.

    Returns:
        A ``gdal.Driver`` for rasters, an ``ogr.Driver`` for vectors.

    Raises:
        Error: if no single driver matches the extension, or the driver lacks
            the required raster/vector capability.
    """
    di = _fetch_driver_infos()

    if not driver_name:
        # NB only the file name can carry the extension (a dot in a directory
        # name is not one); a name without a dot has no extension at all
        file_name = path.rsplit('/', 1)[-1]
        _, dot, ext = file_name.rpartition('.')
        if not dot:
            ext = ''
        # try the extension as given, then lower-cased, so e.g. 'DATA.SHP'
        # matches a driver that declares 'shp'
        names = di.extToName.get(ext) or di.extToName.get(ext.lower())
        if not names:
            raise Error(f'no default driver found for {path!r}')
        if len(names) > 1:
            raise Error(f'multiple drivers found for {path!r}: {names}')
        driver_name = names[0]

    if need_raster:
        if driver_name not in di.rasterNames:
            raise Error(f'driver {driver_name!r} is not raster')
        return gdal.GetDriverByName(driver_name)

    if driver_name not in di.vectorNames:
        raise Error(f'driver {driver_name!r} is not vector')
    return ogr.GetDriverByName(driver_name)
_di_cache: Optional[_DriverInfoCache] = None


def _fetch_driver_infos() -> _DriverInfoCache:
    """Return the driver index, building it on first use.

    The index is published to the module cache only after it is fully
    populated. An error while enumerating drivers therefore does not leave a
    partial index behind; the next call simply retries.
    """
    global _di_cache

    if _di_cache is not None:
        return _di_cache

    cache = _DriverInfoCache(
        infos=[],
        extToName={},
        vectorNames=set(),
        rasterNames=set(),
    )

    for n in range(gdal.GetDriverCount()):
        drv = gdal.GetDriver(n)
        inf = DriverInfo(
            index=n,
            name=str(drv.ShortName),
            longName=str(drv.LongName),
            metaData=dict(drv.GetMetadata() or {}),
        )
        cache.infos.append(inf)

        # DMD_EXTENSIONS is a space-separated list of file extensions
        for ext in inf.metaData.get(gdal.DMD_EXTENSIONS, '').split():
            cache.extToName.setdefault(ext, []).append(inf.name)
        if inf.metaData.get('DCAP_VECTOR') == 'YES':
            cache.vectorNames.add(inf.name)
        if inf.metaData.get('DCAP_RASTER') == 'YES':
            cache.rasterNames.add(inf.name)

    _di_cache = cache
    return _di_cache
_srs_cache = {}


def _srs_from_srid(srid):
    """Return a cached ``osr.SpatialReference`` for an EPSG code.

    The SRS is cached only after ``ImportFromEPSG`` succeeds, so a failed import
    cannot leave an empty SRS in the cache for later calls.

    NB the returned object is shared between callers; do not modify it.
    """
    srs = _srs_cache.get(srid)
    if srs is None:
        srs = osr.SpatialReference()
        srs.ImportFromEPSG(srid)
        _srs_cache[srid] = srs
    return srs
def _attr_from_ogr(gd_feature: ogr.Feature, gtype: int, idx: int, encoding: str = 'utf8'):
    """Convert the value of field ``idx`` of an OGR feature to a Python value.

    Args:
        gd_feature: The OGR feature.
        gtype: OGR field type of the field (an ``ogr.OFT*`` constant).
        idx: Field index.
        encoding: Encoding used to decode string fields; if empty, raw bytes are returned.

    Returns:
        The value, or None for null fields, unsupported field types and
        date/time values that cannot be converted.
    """
    if gd_feature.IsFieldNull(idx):
        return None

    if gtype == ogr.OFTString:
        b = gd_feature.GetFieldAsBinary(idx)
        return b.decode(encoding) if encoding else b

    if gtype == ogr.OFTStringList:
        return gd_feature.GetFieldAsStringList(idx)

    if gtype in {ogr.OFTDate, ogr.OFTTime, ogr.OFTDateTime}:
        # python GetFieldAsDateTime appears to use float seconds, as in
        # GetFieldAsDateTime (int i, int *pnYear, int *pnMonth, int *pnDay, int *pnHour, int *pnMinute, float *pfSecond, int *pnTZFlag)
        # NOTE(review): OFTTime values have no date part; if datetimex.new rejects
        # it with ValueError, such values come back as None - confirm intended.
        v = gd_feature.GetFieldAsDateTime(idx)
        sec, fsec = divmod(v[5], 1)
        try:
            return datetimex.new(v[0], v[1], v[2], v[3], v[4], int(sec), int(fsec * 1e6), tz=_tzflag_to_tz(v[6]))
        except ValueError:
            return None

    # NB field *subtypes* (ogr.OFST*) share numeric values with field types
    # (OFSTBoolean == OFTIntegerList, OFSTFloat32 == OFTRealList), so a subtype
    # must be read from the field definition and never compared with `gtype`.
    if gtype == ogr.OFTInteger:
        if gd_feature.GetFieldDefnRef(idx).GetSubType() == ogr.OFSTBoolean:
            return gd_feature.GetFieldAsInteger(idx) != 0
        return gd_feature.GetFieldAsInteger(idx)
    if gtype == ogr.OFTInteger64:
        # the 32-bit getter would clamp 64-bit values
        return gd_feature.GetFieldAsInteger64(idx)
    if gtype == ogr.OFTIntegerList:
        return gd_feature.GetFieldAsIntegerList(idx)
    if gtype == ogr.OFTInteger64List:
        return gd_feature.GetFieldAsInteger64List(idx)
    if gtype == ogr.OFTReal:
        return gd_feature.GetFieldAsDouble(idx)
    if gtype == ogr.OFTRealList:
        return gd_feature.GetFieldAsDoubleList(idx)
    if gtype == ogr.OFTBinary:
        return gd_feature.GetFieldAsBinary(idx)

    return None
536def _tzflag_to_tz(tzflag):
537 # see gdal/ogr/ogrutils.cpp OGRGetISO8601DateTime
539 if tzflag == 0 or tzflag == 1:
540 return ''
541 if tzflag == 100:
542 return 'UTC'
543 if tzflag % 4 != 0:
544 # @TODO
545 raise Error(f'unsupported timezone {tzflag=}')
546 hrs = (100 - tzflag) // 4
547 return f'Etc/GMT{hrs:+}'
def _attr_to_ogr(gd_feature: ogr.Feature, gtype: int, idx: int, value, encoding):
    """Store a Python value in field ``idx`` of an OGR feature.

    Args:
        gd_feature: The OGR feature.
        gtype: OGR field type of the field (an ``ogr.OFT*`` constant).
        idx: Field index.
        value: The value; list fields expect an iterable.
        encoding: Currently unused; strings are handed to OGR as-is.
            NOTE(review): reading decodes with this encoding, so a non-utf8
            encoding does not round-trip - confirm.

    Returns:
        Whatever the OGR setter returns.
    """
    if gtype == ogr.OFTDate:
        return gd_feature.SetField(idx, datetimex.to_iso_date_string(value))
    if gtype == ogr.OFTTime:
        return gd_feature.SetField(idx, datetimex.to_iso_time_string(value))
    if gtype == ogr.OFTDateTime:
        return gd_feature.SetField(idx, datetimex.to_iso_string(value))

    # NB field *subtypes* (ogr.OFST*) share numeric values with field types
    # (OFSTBoolean == OFTIntegerList, OFSTFloat32 == OFTRealList), so a subtype
    # must be read from the field definition and never compared with `gtype`.
    if gtype == ogr.OFTInteger:
        if gd_feature.GetFieldDefnRef(idx).GetSubType() == ogr.OFSTBoolean:
            return gd_feature.SetField(idx, 1 if value else 0)
        return gd_feature.SetField(idx, int(value))
    if gtype == ogr.OFTInteger64:
        return gd_feature.SetField(idx, int(value))

    # lists go through the typed list setters
    if gtype == ogr.OFTIntegerList:
        return gd_feature.SetFieldIntegerList(idx, [int(x) for x in value])
    if gtype == ogr.OFTInteger64List:
        return gd_feature.SetFieldInteger64List(idx, [int(x) for x in value])
    if gtype == ogr.OFTReal:
        return gd_feature.SetField(idx, float(value))
    if gtype == ogr.OFTRealList:
        return gd_feature.SetFieldDoubleList(idx, [float(x) for x in value])
    if gtype == ogr.OFTStringList:
        return gd_feature.SetFieldStringList(idx, [str(x) for x in value])

    return gd_feature.SetField(idx, value)
def is_attribute_supported(typ):
    """Tell whether attributes of type ``typ`` can be stored in an OGR field."""
    return _ATTR_TO_OGR.get(typ) is not None
# gws attribute type -> OGR field type used when creating layer fields.
# NB booleans are mapped to plain integer fields (no boolean subtype is set).
_ATTR_TO_OGR = {
    gws.AttributeType.bool: ogr.OFTInteger,
    gws.AttributeType.bytes: ogr.OFTBinary,
    gws.AttributeType.date: ogr.OFTDate,
    gws.AttributeType.datetime: ogr.OFTDateTime,
    gws.AttributeType.float: ogr.OFTReal,
    gws.AttributeType.floatlist: ogr.OFTRealList,
    gws.AttributeType.int: ogr.OFTInteger,
    gws.AttributeType.intlist: ogr.OFTIntegerList,
    gws.AttributeType.str: ogr.OFTString,
    gws.AttributeType.strlist: ogr.OFTStringList,
    gws.AttributeType.time: ogr.OFTTime,
}

# OGR field type -> gws attribute type, used when describing existing layers.
# Field types missing here (e.g. OFTWideString) are skipped by VectorLayer.describe.
_OGR_TO_ATTR = {
    ogr.OFTBinary: gws.AttributeType.bytes,
    ogr.OFTDate: gws.AttributeType.date,
    ogr.OFTDateTime: gws.AttributeType.datetime,
    ogr.OFTReal: gws.AttributeType.float,
    ogr.OFTRealList: gws.AttributeType.floatlist,
    ogr.OFTInteger: gws.AttributeType.int,
    ogr.OFTIntegerList: gws.AttributeType.intlist,
    ogr.OFTInteger64: gws.AttributeType.int,
    ogr.OFTInteger64List: gws.AttributeType.intlist,
    ogr.OFTString: gws.AttributeType.str,
    ogr.OFTStringList: gws.AttributeType.strlist,
    ogr.OFTTime: gws.AttributeType.time,
}
# gws geometry type -> OGR geometry type (the plain 2D wkb* variants).
_GEOM_TO_OGR = {
    gws.GeometryType.curve: ogr.wkbCurve,
    gws.GeometryType.geometrycollection: ogr.wkbGeometryCollection,
    gws.GeometryType.linestring: ogr.wkbLineString,
    gws.GeometryType.multicurve: ogr.wkbMultiCurve,
    gws.GeometryType.multilinestring: ogr.wkbMultiLineString,
    gws.GeometryType.multipoint: ogr.wkbMultiPoint,
    gws.GeometryType.multipolygon: ogr.wkbMultiPolygon,
    gws.GeometryType.multisurface: ogr.wkbMultiSurface,
    gws.GeometryType.point: ogr.wkbPoint,
    gws.GeometryType.polygon: ogr.wkbPolygon,
    gws.GeometryType.polyhedralsurface: ogr.wkbPolyhedralSurface,
    gws.GeometryType.surface: ogr.wkbSurface,
}

# OGR geometry type -> gws geometry type: the exact inverse of the mapping
# above, derived from it so the two directions cannot drift apart.
_OGR_TO_GEOM = {ogr_type: geom_type for geom_type, ogr_type in _GEOM_TO_OGR.items()}