Coverage for gws-app/gws/gis/gdalx/__init__.py: 0%

333 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-04-17 01:37 +0200

1"""GDAL/OGR wrapper.""" 

2 

3from typing import Optional, cast 

4 

5import contextlib 

6 

7from osgeo import gdal 

8from osgeo import ogr 

9from osgeo import osr 

10 

11import gws 

12import gws.base.shape 

13import gws.gis.crs 

14import gws.lib.datetimex as datetimex 

15 

16 

class Error(gws.Error):
    """Exception type raised by the GDAL/OGR wrapper functions of this module."""

19 

20 

class DriverInfo(gws.Data):
    """Description of a single GDAL driver."""

    # Position of the driver in the GDAL driver enumeration.
    index: int
    # Short driver name (``ShortName`` of the GDAL driver).
    name: str
    # Long driver name (``LongName`` of the GDAL driver).
    longName: str
    # Driver metadata as returned by ``GetMetadata()``.
    metaData: dict

26 

27 

class _DriverInfoCache(gws.Data):
    """Module-level cache of everything known about the installed GDAL drivers."""

    # All drivers, in enumeration order.
    infos: list[DriverInfo]
    # File extension -> list of driver names that declare this extension.
    extToName: dict
    # Names of drivers whose metadata has DCAP_VECTOR == 'YES'.
    vectorNames: set[str]
    # Names of drivers whose metadata has DCAP_RASTER == 'YES'.
    rasterNames: set[str]

33 

34 

def drivers() -> list[DriverInfo]:
    """Enumerate GDAL drivers.

    Returns:
        The list of driver descriptions held in the driver info cache.
    """

    return _fetch_driver_infos().infos

40 

41 

def open_raster(
    path: str,
    mode: str = 'r',
    driver: str = '',
    default_crs: Optional[gws.Crs] = None,
    **opts
) -> 'RasterDataSet':
    """Open or create a raster DataSet at the given path.

    Args:
        path: File path.
        mode: 'r' (=read), 'a' (=update), 'w' (=create/write)
        driver: Driver name, if omitted, will be suggested from the path extension.
        default_crs: Default CRS for geometries (fallback to Webmercator).
        opts: Options for gdal.OpenEx/CreateDataSource.

    Returns:
        DataSet object.
    """

    # Raster datasets carry no string encoding, hence ``encoding=None``.
    return _open(
        path=path,
        mode=mode,
        driver=driver,
        encoding=None,
        default_crs=default_crs,
        opts=opts,
        need_raster=True,
    )

60 

61 

def open_vector(
    path: str,
    mode: str = 'r',
    driver: str = '',
    encoding: str = 'utf8',
    default_crs: Optional[gws.Crs] = None,
    **opts
) -> 'VectorDataSet':
    """Open or create a vector DataSet at the given path.

    Args:
        path: File path.
        mode: 'r' (=read), 'a' (=update), 'w' (=create/write)
        driver: Driver name, if omitted, will be suggested from the path extension.
        encoding: If not None, strings will be automatically decoded.
        default_crs: Default CRS for geometries (fallback to Webmercator).
        opts: Options for gdal.OpenEx/CreateDataSource.

    Returns:
        DataSet object.
    """

    return _open(
        path=path,
        mode=mode,
        driver=driver,
        encoding=encoding,
        default_crs=default_crs,
        opts=opts,
        need_raster=False,
    )

86 

87 

def _open(path, mode, driver, encoding, default_crs, opts, need_raster):
    """Open or create a GDAL dataset and wrap it into a DataSet object.

    Args:
        path: File path.
        mode: 'r' (=read), 'a' (=update), 'w' (=create/write); an empty value means 'r'.
        driver: Driver name, or an empty value to pick the driver from the path extension.
        encoding: String encoding, stored on the returned DataSet.
        default_crs: Default CRS, falls back to Webmercator.
        opts: Dict of keyword options passed to ``gdal.OpenEx`` or ``CreateDataSource``.
        need_raster: If true, a ``RasterDataSet`` is returned, otherwise a ``VectorDataSet``.

    Returns:
        A ``RasterDataSet`` or a ``VectorDataSet``.

    Raises:
        Error: If the mode is invalid, no suitable driver is found,
            the file does not exist (modes 'r' and 'a') or GDAL returns no dataset.
    """

    mode = mode or 'r'

    # NB: compare against a tuple of modes. A test like ``mode not in 'rwa'`` is a substring
    # test and would let combinations such as 'rw' or 'wa' through.
    if mode not in ('r', 'w', 'a'):
        raise Error(f'invalid open mode {mode!r}')

    gdal.UseExceptions()

    drv = _driver_from_args(path, driver, need_raster)
    default_crs = default_crs or gws.gis.crs.WEBMERCATOR
    dataset_cls = RasterDataSet if need_raster else VectorDataSet

    if mode == 'w':
        gd = drv.CreateDataSource(path, **opts)
        if gd is None:
            raise Error(f'cannot create {path!r}')
        return dataset_cls(path, gd, encoding, default_crs)

    if not gws.u.is_file(path):
        raise Error(f'file not found {path!r}')

    flags = gdal.OF_VERBOSE_ERROR
    flags |= gdal.OF_UPDATE if mode == 'a' else gdal.OF_READONLY
    flags |= gdal.OF_RASTER if need_raster else gdal.OF_VECTOR

    gd = gdal.OpenEx(path, flags, **opts)
    if gd is None:
        raise Error(f'cannot open {path!r}')

    return dataset_cls(path, gd, encoding, default_crs)

127 

128 

def open_from_image(image: gws.Image, bounds: gws.Bounds) -> 'RasterDataSet':
    """Build an in-memory raster DataSet from an Image.

    The image array is copied band by band into a dataset of the GDAL 'MEM' driver,
    which is then georeferenced using the extent and the CRS of ``bounds``.

    Args:
        image: Image object
        bounds: geographic bounds

    Returns:
        A ``RasterDataSet`` with an empty path.
    """

    gdal.UseExceptions()

    mem_driver = gdal.GetDriverByName('MEM')
    pixels = image.to_array()
    # the array is indexed as [row, column, band]
    band_count = pixels.shape[2]

    gd = mem_driver.Create('', pixels.shape[1], pixels.shape[0], band_count, gdal.GDT_Byte)
    for band_no in range(1, band_count + 1):
        # GDAL band numbers are 1-based
        gd.GetRasterBand(band_no).WriteArray(pixels[:, :, band_no - 1])

    ext = bounds.extent

    res_x = (ext[2] - ext[0]) / gd.RasterXSize
    res_y = (ext[1] - ext[3]) / gd.RasterYSize

    # geotransform: origin x, pixel width, 0, origin y, 0, pixel height
    gd.SetGeoTransform((ext[0], res_x, 0, ext[3], 0, res_y))
    gd.SetSpatialRef(_srs_from_srid(bounds.crs.srid))

    return RasterDataSet('', gd)

165 

166 

167## 

168 

class _DataSet:
    """Common base of raster and vector DataSets: holds the GDAL dataset handle.

    Can be used as a context manager, in which case the dataset is closed on exit.
    """

    # Underlying GDAL dataset, set to None once the DataSet is closed.
    gdDataset: gdal.Dataset
    # Driver of the underlying dataset.
    gdDriver: gdal.Driver
    # Path the dataset was opened from or created at.
    path: str
    # Description (name) of the driver.
    driverName: str
    # CRS used when none is given explicitly.
    defaultCrs: gws.Crs
    # String encoding.
    encoding: str

    def __init__(self, path, gd_dataset, encoding='utf8', default_crs: Optional[gws.Crs] = None):
        """Wrap an already opened GDAL dataset.

        Args:
            path: Dataset path.
            gd_dataset: GDAL dataset object.
            encoding: String encoding.
            default_crs: Default CRS, falls back to Webmercator.
        """

        self.path = path
        self.gdDataset = gd_dataset
        self.gdDriver = self.gdDataset.GetDriver()
        self.driverName = self.gdDriver.GetDescription()
        self.encoding = encoding
        self.defaultCrs = default_crs or gws.gis.crs.WEBMERCATOR

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()
        # never suppress exceptions raised inside the ``with`` body
        return False

    def close(self):
        """Flush the dataset and drop the reference to it.

        Calling ``close`` on an already closed DataSet is a no-op, so an explicit
        ``close()`` inside a ``with`` block does not fail when the block exits.
        """

        gd = getattr(self, 'gdDataset', None)
        if gd is None:
            return
        try:
            gd.FlushCache()
        finally:
            # drop the handle even if flushing failed, so that the dataset is not kept open
            setattr(self, 'gdDataset', None)

195 

196 

class RasterDataSet(_DataSet):
    """DataSet wrapping a GDAL raster dataset."""

    def create_copy(self, path: str, driver: str = '', strict=False, **opts):
        """Create a copy of a DataSet.

        Args:
            path: Path of the copy.
            driver: Driver name, if omitted, will be suggested from the path extension.
            strict: Passed to ``CreateCopy`` as 1 (true) or 0 (false).
            opts: Keyword options for ``CreateCopy``.
        """

        gdal.UseExceptions()

        copy_driver = _driver_from_args(path, driver, need_raster=True)
        strict_flag = 1 if strict else 0

        gd_copy = copy_driver.CreateCopy(path, self.gdDataset, strict_flag, **opts)
        gd_copy.SetMetadata(self.gdDataset.GetMetadata())
        gd_copy.FlushCache()
        # drop the local reference to the copy right after flushing
        del gd_copy

208 

209 

class VectorDataSet(_DataSet):
    """DataSet wrapping a GDAL vector dataset."""

    @contextlib.contextmanager
    def transaction(self):
        """Context manager that runs its body inside a dataset transaction.

        The transaction is committed when the body completes and rolled back
        (with the exception re-raised) if the body or the commit fails.
        """

        self.gdDataset.StartTransaction()
        try:
            yield self
            self.gdDataset.CommitTransaction()
        except BaseException:
            self.gdDataset.RollbackTransaction()
            raise

    def create_layer(
        self,
        name: str,
        columns: dict[str, gws.AttributeType],
        geometry_type: gws.GeometryType = None,
        crs: gws.Crs = None,
        overwrite=False,
        *options,
    ) -> 'VectorLayer':
        """Create a layer in this DataSet.

        Args:
            name: Layer name.
            columns: Mapping of column names to attribute types.
            geometry_type: Geometry type; if omitted, the layer gets ``ogr.wkbUnknown`` and no SRS.
            crs: CRS of the geometry, defaults to the default CRS of the DataSet.
            overwrite: If true, the 'OVERWRITE=YES' option is added.
            options: Further layer creation options.

        Returns:
            The new layer.
        """

        layer_options = [*options]
        if overwrite:
            layer_options.append('OVERWRITE=YES')

        ogr_geom_type = ogr.wkbUnknown
        srs = None

        if geometry_type:
            ogr_geom_type = _GEOM_TO_OGR.get(geometry_type)
            if not ogr_geom_type:
                gws.log.warning(f'gdal: unsupported {geometry_type=}')
                ogr_geom_type = ogr.wkbUnknown
            crs = crs or self.defaultCrs
            srs = _srs_from_srid(crs.srid)

        gd_layer = self.gdDataset.CreateLayer(
            name,
            geom_type=ogr_geom_type,
            srs=srs,
            options=layer_options,
        )

        for col_name in columns:
            field_defn = ogr.FieldDefn(col_name, _ATTR_TO_OGR[columns[col_name]])
            gd_layer.CreateField(field_defn)

        return VectorLayer(self, gd_layer)

    def layers(self) -> list['VectorLayer']:
        """Return all layers of this DataSet, in index order."""

        result = []
        for n in range(self.gdDataset.GetLayerCount()):
            result.append(VectorLayer(self, self.gdDataset.GetLayerByIndex(n)))
        return result

    def layer(self, name_or_index: str | int) -> Optional['VectorLayer']:
        """Return a layer by its name or index, or None if GDAL finds no such layer."""

        gd_layer = self.gdDataset.GetLayer(name_or_index)
        if not gd_layer:
            return None
        return VectorLayer(self, gd_layer)

263 

264 

class VectorLayer:
    """Wrapper around a single OGR layer of a ``VectorDataSet``."""

    # Layer name, taken from the layer definition.
    name: str
    # Underlying OGR layer.
    gdLayer: ogr.Layer
    # Feature definition of the layer.
    gdDefn: ogr.FeatureDefn
    # String encoding, inherited from the DataSet.
    encoding: str
    # Default CRS, inherited from the DataSet.
    defaultCrs: gws.Crs

    def __init__(self, ds: VectorDataSet, gd_layer: ogr.Layer):
        self.gdLayer = gd_layer
        self.gdDefn = self.gdLayer.GetLayerDefn()
        self.name = self.gdDefn.GetName()
        self.encoding = ds.encoding
        self.defaultCrs = ds.defaultCrs

    def _srid_from_srs(self, srs) -> int:
        """Return the authority code of an OGR spatial reference as an int.

        Falls back to the srid of the default CRS if the reference is missing
        or has no numeric authority code.
        """

        code = srs.GetAuthorityCode(None) if srs is not None else None
        if code:
            try:
                return int(code)
            except ValueError:
                pass
        return self.defaultCrs.srid

    def describe(self) -> gws.DataSetDescription:
        """Describe the columns and the geometry of the layer.

        Fields and geometry fields with types not present in the mapping tables are skipped.
        If there are several geometry columns, the description refers to the last one.
        """

        desc = gws.DataSetDescription(
            columns=[],
            columnMap={},
            fullName=self.name,
            geometryName='',
            geometrySrid=0,
            geometryType='',
            name=self.name,
            schema='',
        )

        cols = []

        fid_col = self.gdLayer.GetFIDColumn()
        if fid_col:
            cols.append(gws.ColumnDescription(
                name=fid_col,
                type=_OGR_TO_ATTR[ogr.OFTInteger],
                nativeType=ogr.OFTInteger,
                isPrimaryKey=True,
                columnIndex=0,
            ))

        for i in range(self.gdDefn.GetFieldCount()):
            fdef: ogr.FieldDefn = self.gdDefn.GetFieldDefn(i)
            typ = fdef.GetType()
            if typ not in _OGR_TO_ATTR:
                continue
            cols.append(gws.ColumnDescription(
                name=fdef.GetName(),
                type=_OGR_TO_ATTR[typ],
                nativeType=typ,
                columnIndex=i,
            ))

        for i in range(self.gdDefn.GetGeomFieldCount()):
            gdef: ogr.GeomFieldDefn = self.gdDefn.GetGeomFieldDefn(i)
            typ = gdef.GetType()
            if typ not in _OGR_TO_GEOM:
                continue
            cols.append(gws.ColumnDescription(
                name=gdef.GetName(),
                type=gws.AttributeType.geometry,
                nativeType=typ,
                columnIndex=i,
                geometryType=_OGR_TO_GEOM[typ],
                # a geometry field may have no SRS or no authority code
                geometrySrid=self._srid_from_srs(gdef.GetSpatialRef()),
            ))

        desc.columns = cols
        desc.columnMap = {c.name: c for c in cols}

        for c in cols:
            # NB take the last geom
            if c.geometryType:
                desc.geometryName = c.name
                desc.geometryType = c.geometryType
                desc.geometrySrid = c.geometrySrid

        return desc

    def insert(self, records: list[gws.FeatureRecord]) -> list[int]:
        """Insert records into the layer.

        Args:
            records: Records to insert. An integer ``uid`` is used as the feature id.

        Returns:
            Feature ids of the created features, in the order of ``records``.

        Raises:
            Error: If an attribute value cannot be written to its field.
        """

        desc = self.describe()
        fids = []

        for rec in records:
            gd_feature = ogr.Feature(self.gdDefn)
            if desc.geometryType and rec.shape:
                gd_feature.SetGeometry(
                    ogr.CreateGeometryFromWkt(
                        rec.shape.to_wkt(),
                        _srs_from_srid(rec.shape.crs.srid)
                    ))

            if rec.uid and isinstance(rec.uid, int):
                gd_feature.SetFID(rec.uid)

            for col in desc.columns:
                if col.geometryType or col.isPrimaryKey:
                    continue
                val = rec.attributes.get(col.name)
                if val is None:
                    continue
                try:
                    _attr_to_ogr(gd_feature, int(col.nativeType), col.columnIndex, val, self.encoding)
                except Exception as exc:
                    raise Error(f'field cannot be set: {col.name=} {val=}') from exc

            self.gdLayer.CreateFeature(gd_feature)
            fids.append(gd_feature.GetFID())

        return fids

    def count(self, force=False):
        """Return the feature count of the layer; ``force`` is passed to GDAL as 1 or 0."""

        return self.gdLayer.GetFeatureCount(force=1 if force else 0)

    def get_all(self) -> list[gws.FeatureRecord]:
        """Read all features of the layer, from the start, as feature records."""

        records = []

        self.gdLayer.ResetReading()

        while True:
            gd_feature = self.gdLayer.GetNextFeature()
            if not gd_feature:
                break
            records.append(self._feature_record(gd_feature))

        return records

    def get(self, fid: int) -> Optional[gws.FeatureRecord]:
        """Return the record for a feature id, or None if GDAL returns no feature."""

        gd_feature = self.gdLayer.GetFeature(fid)
        if gd_feature:
            return self._feature_record(gd_feature)
        return None

    def _feature_record(self, gd_feature):
        """Convert an OGR feature into a feature record (attributes, shape, uid)."""

        rec = gws.FeatureRecord(
            attributes={},
            shape=None,
            meta={'layerName': self.name},
            uid=str(gd_feature.GetFID()),
        )

        for i in range(gd_feature.GetFieldCount()):
            gd_field_defn: ogr.FieldDefn = gd_feature.GetFieldDefnRef(i)
            name = gd_field_defn.GetName()
            val = _attr_from_ogr(gd_feature, gd_field_defn.type, i, self.encoding)
            rec.attributes[name] = val

        cnt = gd_feature.GetGeomFieldCount()
        if cnt > 0:
            # NB take the last geom
            # @TODO multigeometry support
            gd_geom = gd_feature.GetGeomFieldRef(cnt - 1)
            if gd_geom:
                srs = gd_geom.GetSpatialReference()
                srid = srs.GetAuthorityCode(None) if srs else None
                # use the default CRS when the geometry has no SRS,
                # the SRS has no authority code, or the code is not a known CRS
                crs = (gws.gis.crs.get(srid) if srid else None) or self.defaultCrs
                rec.shape = gws.base.shape.from_wkt(gd_geom.ExportToWkt(), crs)

        return rec

425 

426 

427## 

428 

429 

def _driver_from_args(path, driver_name, need_raster):
    """Find the GDAL/OGR driver for a path and an optional explicit driver name.

    Args:
        path: File path, its extension is used if ``driver_name`` is empty.
        driver_name: Driver name or an empty value.
        need_raster: If true, a raster driver is required, otherwise a vector driver.

    Returns:
        A driver object from ``gdal.GetDriverByName`` (raster) or ``ogr.GetDriverByName`` (vector).

    Raises:
        Error: If no driver or no unique driver matches the extension,
            or the driver does not support the requested kind of data.
    """

    di = _fetch_driver_infos()
    wanted_names = di.rasterNames if need_raster else di.vectorNames

    if not driver_name:
        ext = path.split('.')[-1]
        # try the extension as written first, then lower-cased,
        # so that 'FILE.TIF' resolves the same way as 'file.tif'
        names = di.extToName.get(ext) or di.extToName.get(ext.lower())
        if not names:
            raise Error(f'no default driver found for {path!r}')
        if len(names) > 1:
            # several drivers claim this extension: the requested kind may still decide it
            narrowed = [n for n in names if n in wanted_names]
            if len(narrowed) != 1:
                raise Error(f'multiple drivers found for {path!r}: {names}')
            names = narrowed
        driver_name = names[0]

    if need_raster:
        if driver_name not in di.rasterNames:
            raise Error(f'driver {driver_name!r} is not raster')
        return gdal.GetDriverByName(driver_name)

    if driver_name not in di.vectorNames:
        raise Error(f'driver {driver_name!r} is not vector')
    return ogr.GetDriverByName(driver_name)

453 

454 

_di_cache: Optional[_DriverInfoCache] = None


def _fetch_driver_infos() -> _DriverInfoCache:
    """Return the driver info cache, populating it from GDAL on first use.

    The cache is built in a local object and only stored in the module global once
    it is complete, so a failure during enumeration never leaves a partial cache behind.
    """

    global _di_cache

    if _di_cache:
        return _di_cache

    cache = _DriverInfoCache(
        infos=[],
        extToName={},
        vectorNames=set(),
        rasterNames=set(),
    )

    for n in range(gdal.GetDriverCount()):
        drv = gdal.GetDriver(n)
        inf = DriverInfo(
            index=n,
            name=str(drv.ShortName),
            longName=str(drv.LongName),
            metaData=dict(drv.GetMetadata() or {})
        )
        cache.infos.append(inf)

        # DMD_EXTENSIONS is a whitespace-separated list of file extensions
        for e in inf.metaData.get(gdal.DMD_EXTENSIONS, '').split():
            cache.extToName.setdefault(e, []).append(inf.name)
        if inf.metaData.get('DCAP_VECTOR') == 'YES':
            cache.vectorNames.add(inf.name)
        if inf.metaData.get('DCAP_RASTER') == 'YES':
            cache.rasterNames.add(inf.name)

    _di_cache = cache
    return cache

489 

490 

_srs_cache = {}


def _srs_from_srid(srid):
    """Return a cached ``osr.SpatialReference`` for an EPSG code.

    The reference is stored in the cache only after ``ImportFromEPSG`` has returned,
    so an import that raises does not leave an uninitialised object cached for that srid.
    """

    srs = _srs_cache.get(srid)
    if srs is None:
        srs = osr.SpatialReference()
        srs.ImportFromEPSG(srid)
        _srs_cache[srid] = srs
    return srs

499 

500 

def _attr_from_ogr(gd_feature: ogr.Feature, gtype: int, idx: int, encoding: str = 'utf8'):
    """Read a field of an OGR feature as a Python value.

    Args:
        gd_feature: OGR feature.
        gtype: OGR field type (``ogr.OFT*``) of the field.
        idx: Field index.
        encoding: Encoding used to decode string fields; if empty, raw bytes are returned.

    Returns:
        The field value, or None if the field is null, the type is not supported,
        or a date/time value cannot be converted.
    """

    if gd_feature.IsFieldNull(idx):
        return None

    if gtype == ogr.OFTString:
        b = gd_feature.GetFieldAsBinary(idx)
        if encoding:
            return b.decode(encoding)
        return b

    if gtype in {ogr.OFTDate, ogr.OFTTime, ogr.OFTDateTime}:
        # python GetFieldAsDateTime appears to use float seconds, as in
        # GetFieldAsDateTime (int i, int *pnYear, int *pnMonth, int *pnDay, int *pnHour, int *pnMinute, float *pfSecond, int *pnTZFlag)
        #
        v = gd_feature.GetFieldAsDateTime(idx)
        sec, fsec = divmod(v[5], 1)
        try:
            return datetimex.new(v[0], v[1], v[2], v[3], v[4], int(sec), int(fsec * 1e6), tz=_tzflag_to_tz(v[6]))
        except ValueError:
            return None

    # NB: ``gtype`` is a field *type*. Subtype constants (ogr.OFST*) share the numeric range
    # of the type constants (OFSTBoolean == OFTIntegerList, OFSTFloat32 == OFTRealList),
    # so they must be read from the field definition and never compared with ``gtype``.

    if gtype == ogr.OFTInteger:
        val = gd_feature.GetFieldAsInteger(idx)
        if gd_feature.GetFieldDefnRef(idx).GetSubType() == ogr.OFSTBoolean:
            return val != 0
        return val
    if gtype == ogr.OFTInteger64:
        # the plain integer getter is 32-bit
        return gd_feature.GetFieldAsInteger64(idx)
    if gtype == ogr.OFTIntegerList:
        return gd_feature.GetFieldAsIntegerList(idx)
    if gtype == ogr.OFTInteger64List:
        return gd_feature.GetFieldAsInteger64List(idx)
    if gtype == ogr.OFTReal:
        return gd_feature.GetFieldAsDouble(idx)
    if gtype == ogr.OFTRealList:
        return gd_feature.GetFieldAsDoubleList(idx)
    if gtype == ogr.OFTStringList:
        return gd_feature.GetFieldAsStringList(idx)
    if gtype == ogr.OFTBinary:
        return gd_feature.GetFieldAsBinary(idx)

    return None

534 

535 

536def _tzflag_to_tz(tzflag): 

537 # see gdal/ogr/ogrutils.cpp OGRGetISO8601DateTime 

538 

539 if tzflag == 0 or tzflag == 1: 

540 return '' 

541 if tzflag == 100: 

542 return 'UTC' 

543 if tzflag % 4 != 0: 

544 # @TODO 

545 raise Error(f'unsupported timezone {tzflag=}') 

546 hrs = (100 - tzflag) // 4 

547 return f'Etc/GMT{hrs:+}' 

548 

549 

def _attr_to_ogr(gd_feature: ogr.Feature, gtype: int, idx: int, value, encoding):
    """Write a Python value into a field of an OGR feature.

    Args:
        gd_feature: OGR feature.
        gtype: OGR field type (``ogr.OFT*``) of the field.
        idx: Field index.
        value: Value to write, coerced according to the field type.
        encoding: Not used by this function.

    Returns:
        The result of ``gd_feature.SetField``.
    """

    if gtype == ogr.OFTDate:
        return gd_feature.SetField(idx, datetimex.to_iso_date_string(value))
    if gtype == ogr.OFTTime:
        return gd_feature.SetField(idx, datetimex.to_iso_time_string(value))
    if gtype == ogr.OFTDateTime:
        return gd_feature.SetField(idx, datetimex.to_iso_string(value))

    # NB: ``gtype`` is a field *type*. Subtype constants (ogr.OFST*) share the numeric range
    # of the type constants (OFSTBoolean == OFTIntegerList, OFSTFloat32 == OFTRealList),
    # so they must be read from the field definition and never compared with ``gtype``.

    if gtype == ogr.OFTInteger:
        if gd_feature.GetFieldDefnRef(idx).GetSubType() == ogr.OFSTBoolean:
            return gd_feature.SetField(idx, int(bool(value)))
        return gd_feature.SetField(idx, int(value))
    if gtype == ogr.OFTInteger64:
        return gd_feature.SetField(idx, int(value))
    if gtype in {ogr.OFTIntegerList, ogr.OFTInteger64List}:
        return gd_feature.SetField(idx, [int(x) for x in value])
    if gtype == ogr.OFTReal:
        return gd_feature.SetField(idx, float(value))
    if gtype == ogr.OFTRealList:
        return gd_feature.SetField(idx, [float(x) for x in value])

    return gd_feature.SetField(idx, value)

569 

570 

def is_attribute_supported(typ):
    """Tell whether an attribute type is a key of the ``_ATTR_TO_OGR`` mapping."""

    supported = typ in _ATTR_TO_OGR
    return supported

573 

574 

# Attribute type -> OGR field type, used when creating layer fields.
# NB: bool is mapped to the plain integer field type.
_ATTR_TO_OGR = {
    gws.AttributeType.bool: ogr.OFTInteger,
    gws.AttributeType.bytes: ogr.OFTBinary,
    gws.AttributeType.date: ogr.OFTDate,
    gws.AttributeType.datetime: ogr.OFTDateTime,
    gws.AttributeType.float: ogr.OFTReal,
    gws.AttributeType.floatlist: ogr.OFTRealList,
    gws.AttributeType.int: ogr.OFTInteger,
    gws.AttributeType.intlist: ogr.OFTIntegerList,
    gws.AttributeType.str: ogr.OFTString,
    gws.AttributeType.strlist: ogr.OFTStringList,
    gws.AttributeType.time: ogr.OFTTime,
}

# OGR field type -> attribute type, used when describing layers.
# NB: not the inverse of the table above: 32- and 64-bit integer types
# both map to int/intlist, and nothing maps back to bool.
_OGR_TO_ATTR = {
    ogr.OFTBinary: gws.AttributeType.bytes,
    ogr.OFTDate: gws.AttributeType.date,
    ogr.OFTDateTime: gws.AttributeType.datetime,
    ogr.OFTReal: gws.AttributeType.float,
    ogr.OFTRealList: gws.AttributeType.floatlist,
    ogr.OFTInteger: gws.AttributeType.int,
    ogr.OFTIntegerList: gws.AttributeType.intlist,
    ogr.OFTInteger64: gws.AttributeType.int,
    ogr.OFTInteger64List: gws.AttributeType.intlist,
    ogr.OFTString: gws.AttributeType.str,
    ogr.OFTStringList: gws.AttributeType.strlist,
    ogr.OFTTime: gws.AttributeType.time,
}

603 

# Geometry type -> OGR geometry type constant.
_GEOM_TO_OGR = {
    gws.GeometryType.curve: ogr.wkbCurve,
    gws.GeometryType.geometrycollection: ogr.wkbGeometryCollection,
    gws.GeometryType.linestring: ogr.wkbLineString,
    gws.GeometryType.multicurve: ogr.wkbMultiCurve,
    gws.GeometryType.multilinestring: ogr.wkbMultiLineString,
    gws.GeometryType.multipoint: ogr.wkbMultiPoint,
    gws.GeometryType.multipolygon: ogr.wkbMultiPolygon,
    gws.GeometryType.multisurface: ogr.wkbMultiSurface,
    gws.GeometryType.point: ogr.wkbPoint,
    gws.GeometryType.polygon: ogr.wkbPolygon,
    gws.GeometryType.polyhedralsurface: ogr.wkbPolyhedralSurface,
    gws.GeometryType.surface: ogr.wkbSurface,
}

# OGR geometry type constant -> geometry type: the exact inverse of the table above,
# with the entries in the same order.
_OGR_TO_GEOM = {ogr_type: geom_type for geom_type, ogr_type in _GEOM_TO_OGR.items()}