
Query API

The toshi_hazard_store.query package provides the main public interface for querying hazard data from parquet datasets.

Quick Start

from toshi_hazard_store import query

# Query hazard curves
curves = query.get_hazard_curves(
    location_codes=["-41.300~174.800"],
    vs30s=[400],
    hazard_model="NSHM_v1.0.4",
    imts=["PGA"],
    aggs=["mean"],
    strategy="d1"  # or "d2", "naive"
)

for curve in curves:
    print(f"{curve.imt} at {curve.nloc_001}: {curve.values}")

# Query gridded hazard
gridded = query.get_gridded_hazard(
    location_grid_id="NZ_0_1_NB_1_1",
    hazard_model_ids=["NSHM_v1.0.4"],
    vs30s=[400.0],
    imts=["PGA"],
    aggs=["mean"],
    poes=[0.02, 0.1]
)

for grid in gridded:
    print(f"Grid {grid.location_grid_id} at POE {grid.poe}: {len(grid.accel_levels)} locations")

toshi_hazard_store.query package

Query package for hazard data retrieval.

This package provides the main public interface for querying hazard data from parquet datasets. It includes:

  • Main query functions: get_hazard_curves, get_gridded_hazard
  • Data models: AggregatedHazard, IMTValue
  • Location utilities: downsample_code, get_hashes

AggregatedHazard dataclass

Represents an aggregated hazard dataset.

Attributes:

  • compatible_calc_id (str): the ID of a compatible calculation for PSHA engines interoperability.
  • hazard_model_id (str): the model that these curves represent.
  • nloc_001 (str): the location string to three decimal places e.g. "-38.330~17.550".
  • nloc_0 (str): the location string to zero decimal places e.g. "-38.0~17.0" (used for partitioning).
  • imt (str): the intensity measure type label e.g. 'PGA', 'SA(5.0)'.
  • vs30 (int): the VS30 integer.
  • agg (str): the aggregation type.
  • values (list[Union[float, IMTValue]]): a list of 44 IMTL values.

Notes

This class is designed to match the table schema for aggregated hazard datasets.

Source code in toshi_hazard_store/query/models.py
@dataclass
class AggregatedHazard:
    """
    Represents an aggregated hazard dataset.

    Attributes:
        compatible_calc_id (str): the ID of a compatible calculation for PSHA engines interoperability.
        hazard_model_id (str): the model that these curves represent.
        nloc_001 (str): the location string to three places e.g. "-38.330~17.550".
        nloc_0 (str): the location string to zero places e.g.  "-38.0~17.0" (used for partitioning).
        imt (str): the intensity measure type label e.g. 'PGA', 'SA(5.0)'.
        vs30 (int): the VS30 integer.
        agg (str): the aggregation type.
        values (list[Union[float, IMTValue]]): a list of 44 IMTL values.

    Notes:
        This class is designed to match the table schema for aggregated hazard datasets.
    """

    compatible_calc_id: str
    hazard_model_id: str

    nloc_001: str
    nloc_0: str
    imt: str
    vs30: int
    agg: str
    values: list["IMTValue"]

    def to_imt_values(self):
        """
        Converts the IMTL values in this object's `values` attribute from a list of floats to a list of `IMTValue`
        objects.
        Returns:
            AggregatedHazard: this object itself.
        """
        new_values = zip(IMT_44_LVLS, self.values)
        self.values = [IMTValue(*x) for x in new_values]
        return self

to_imt_values()

Converts the IMTL values in this object's values attribute from a list of floats to a list of IMTValue objects.

Returns:

  • AggregatedHazard: this object itself.

Source code in toshi_hazard_store/query/models.py
def to_imt_values(self):
    """
    Converts the IMTL values in this object's `values` attribute from a list of floats to a list of `IMTValue`
    objects.
    Returns:
        AggregatedHazard: this object itself.
    """
    new_values = zip(IMT_44_LVLS, self.values)
    self.values = [IMTValue(*x) for x in new_values]
    return self
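The conversion zips the module's 44-entry level list (IMT_44_LVLS) against the stored floats, pairing each level with its value. A self-contained sketch of the same pairing, using a shortened hypothetical level list in place of the real constant:

```python
from dataclasses import dataclass

@dataclass
class IMTValue:
    lvl: float  # the IMT level
    val: float  # the value at that level

# Hypothetical stand-in for IMT_44_LVLS (the real list has 44 entries).
LEVELS = [0.0001, 0.0002, 0.0004]
raw = [0.95, 0.90, 0.80]  # illustrative curve values

# Pair each level with its value, as to_imt_values() does in place.
values = [IMTValue(lvl, val) for lvl, val in zip(LEVELS, raw)]
```

Note that like to_imt_values(), zip silently truncates if the value list is shorter than the level list.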

IMTValue dataclass

Represents an intensity measure type (IMT) value.

Attributes:

  • lvl (float): The level of the IMT value.
  • val (float): The value of the IMT at that level.

Source code in toshi_hazard_store/query/models.py
@dataclass
class IMTValue:
    """Represents an intensity measure type (IMT) value.

    Attributes:
        lvl: The level of the IMT value.
        val: The value of the IMT at that level.
    """

    lvl: float  # noqa: F821
    val: float  # noqa: F821

downsample_code(loc_code, res)

Get a CodedLocation.code at the chosen resolution from the given location code.

Parameters:

  • loc_code (str, required): The location code in format 'latitude~longitude'.
  • res (float, required): Resolution in grid degrees to downsample to.

Returns:

  • str: The downsampled location code.

Examples:

>>> downsample_code('37.7749~-122.4194', 0.1)
'37.8~-122.4'

Source code in toshi_hazard_store/query/hazard_query.py
def downsample_code(loc_code, res) -> str:
    """Get a CodedLocation.code at the chosen resolution from the given location code.

    Args:
        loc_code (str): The location code in format 'latitude~longitude'.
        res (float): Resolution in grid degrees to downsample to.

    Returns:
        str: The downsampled location code.

    Examples:
        >>> downsample_code('37.7749~-122.4194', 0.1)
        '37.8~-122.4'
    """
    lt = loc_code.split("~")
    assert len(lt) == 2
    return CodedLocation(lat=float(lt[0]), lon=float(lt[1]), resolution=res).code

get_gridded_hazard(location_grid_id, hazard_model_ids, vs30s, imts, aggs, poes, dataset_uri=None)

Retrieves gridded hazard from the parquet dataset.

Parameters:

  • location_grid_id (str, required): the grid identifier to query.
  • hazard_model_ids (list, required): List of hazard model identifiers.
  • vs30s (list, required): List of VS30 values.
  • imts (list, required): List of intensity measure types (e.g. 'PGA', 'SA(5.0)').
  • aggs (list, required): List of aggregation types.
  • poes (list, required): List of probability of exceedance values.
  • dataset_uri (Optional[str], default None): optional URI for the dataset. Defaults to the THS_DATASET_GRIDDED_URI env var.

Yields:

  • GriddedHazardPoeLevels: An object containing the gridded hazard data.

Source code in toshi_hazard_store/query/datasets.py
def get_gridded_hazard(
    location_grid_id: str,
    hazard_model_ids: list[str],
    vs30s: list[float],
    imts: list[str],
    aggs: list[str],
    poes: list[float],
    dataset_uri: Optional[str] = None,
) -> Iterator[GriddedHazardPoeLevels]:
    """
    Retrieves gridded hazard from the parquet dataset.

    Args:
      location_grid_id: the grid identifier to query.
      hazard_model_ids (list): List of hazard model identifiers.
      vs30s (list): List of VS30 values.
      imts (list): List of intensity measure types (e.g. 'PGA', 'SA(5.0)').
      aggs (list): List of aggregation types.
      poes (list): List of probability of exceedance values.
      dataset_uri: optional URI for the dataset. Defaults to the THS_DATASET_GRIDDED_URI env var.

    Yields:
      GriddedHazardPoeLevels: An object containing the gridded hazard data.
    """

    log.debug("> get_gridded_hazard")
    t0 = dt.datetime.now()

    gridded_dataset = get_gridded_dataset(dataset_uri)
    flt = (
        (pc.field("location_grid_id") == location_grid_id)
        & (pc.field("aggr").isin(aggs))
        & (pc.field("imt").isin(imts))
        & (pc.field("vs30").isin(vs30s))
        & (pc.field("poe").isin(poes))
        & (pc.field("hazard_model_id").isin(hazard_model_ids))
    )

    log.debug(f"filter: {flt}")
    table = gridded_dataset.to_table(filter=flt)

    t1 = dt.datetime.now()
    log.debug(f"to_table for filter took {(t1 - t0).total_seconds()} seconds.")
    log.debug(f"schema {table.schema}")

    # NB the following emulates the method used in AggregatedHazard, but it's significantly slower than
    # below using pa.Table.to_pandas()
    # column_names = table.schema.names
    # for batch in table.to_batches():  # pragma: no branch
    #     for row in zip(*batch.columns):  # pragma: no branch
    #         # count += 1
    #         # print(row)
    #         vals = (x.as_py() for x in row)
    #         item = {x[0]:x[1] for x in zip(column_names, vals)}
    #         obj = GriddedHazardPoeLevels.model_construct(**item)
    #         yield obj

    df0 = table.to_pandas()
    for row_dict in df0.to_dict(orient="records"):
        # yield GriddedHazardPoeLevels(**row_dict) # SLOW because of the expensive validators on
        # this Pydantic Model class.
        yield GriddedHazardPoeLevels.model_construct(**row_dict)  # FAST

get_hashes(locs, resolution=0.1)

Compute a set of hashes for the given locations at the specified resolution.

Parameters:

  • locs (Iterable[str], required): A collection of location codes in the format 'latitude~longitude'.
  • resolution (float, default 0.1): The resolution to compute hashes at (in grid degrees).

Returns:

  • list: A sorted list of unique location codes, downsampled to the specified resolution.

Source code in toshi_hazard_store/query/hazard_query.py
def get_hashes(locs: Iterable[str], resolution: float = 0.1) -> Iterable[str]:
    """Compute a set of hashes for the given locations at the specified resolution.

    Args:
        locs (Iterable[str]): A collection of location codes in the format 'latitude~longitude'.
        resolution (float, optional): The resolution to compute hashes at (in grid degrees). Defaults to 0.1.

    Returns:
        list: A sorted list of unique location codes, downsampled to the specified resolution.
    """
    hashes = set()
    for loc in locs:
        lt = loc.split("~")
        assert len(lt) == 2
        hashes.add(downsample_code(loc, resolution))
    return sorted(list(hashes))
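The effect of get_hashes is to collapse fine-resolution codes onto their grid cell and de-duplicate. A standalone sketch with simple rounding standing in for CodedLocation (whose grid-snapping behaviour at cell edges may differ), hard-coding one-decimal output to match the default 0.1 resolution:

```python
def downsample(code: str, res: float = 0.1) -> str:
    # hypothetical rounding; the real CodedLocation handles grid snapping,
    # and the ".1f" formatting below only suits res=0.1
    lat, lon = (float(p) for p in code.split("~"))
    return f"{round(lat / res) * res:.1f}~{round(lon / res) * res:.1f}"

# two nearby Wellington points collapse to one 0.1-degree cell
locs = ["-41.300~174.800", "-41.310~174.810", "-36.870~174.770"]
hashes = sorted({downsample(loc) for loc in locs})
```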

get_hazard_curves(location_codes, vs30s, hazard_model, imts, aggs, strategy='naive', dataset_uri=None)

Retrieves aggregated hazard curves from the dataset.

The optional strategy argument can be used to control how the query behaves:

  • 'naive' (the default) lets pyarrow do its normal thing.
  • 'd1' assumes the dataset is partitioned on vs30, generating multiple pyarrow queries from the user args.
  • 'd2' assumes the dataset is partitioned on vs30, nloc_0 and acts accordingly.

These overriding strategies allow the user to tune the query to suit the size of the datasets and the compute resources available; for example, for the full NSHM dataset queried from an AWS Lambda function, the d2 option is optimal.

Parameters:

  • location_codes (list, required): List of location codes.
  • vs30s (list, required): List of VS30 values.
  • hazard_model (str, required): the hazard model id.
  • imts (list, required): List of intensity measure types (e.g. 'PGA', 'SA(5.0)').
  • aggs (list, required): List of aggregation types.
  • strategy (str, default 'naive'): which query strategy to use (options are d1, d2, naive). Other values will use the naive strategy.
  • dataset_uri (Optional[str], default None): optional URI for the dataset. Defaults to the THS_DATASET_AGGR_URI env var.

Yields:

  • AggregatedHazard: An object containing the aggregated hazard curve data.

Raises:

  • RuntimeWarning: describing any dataset partitions that could not be opened.

Source code in toshi_hazard_store/query/datasets.py
def get_hazard_curves(
    location_codes: list[str],
    vs30s: list[int],
    hazard_model: str,
    imts: list[str],
    aggs: list[str],
    strategy: str = "naive",
    dataset_uri: Optional[str] = None,
) -> Iterator[AggregatedHazard]:
    """
    Retrieves aggregated hazard curves from the dataset.

    The optional `strategy` argument can be used to control how the query behaves:

     - 'naive' (the default) lets pyarrow do its normal thing.
     - 'd1' assumes the dataset is partitioned on `vs30`, generating multiple pyarrow queries from the user args.
     - 'd2' assumes the dataset is partitioned on `vs30, nloc_0` and acts accordingly.

    These overriding strategies allow the user to tune the query to suit the size of the datasets and the
    compute resources available. e.g. for the full NSHM, with an AWS lambda function, the `d2` option is optimal.

    Args:
      location_codes (list): List of location codes.
      vs30s (list): List of VS30 values.
      hazard_model: the hazard model id.
      imts (list): List of intensity measure types (e.g. 'PGA', 'SA(5.0)').
      aggs (list): List of aggregation types.
      strategy: which query strategy to use (options are `d1`, `d2`, `naive`).
          Other values will use the `naive` strategy.
      dataset_uri: optional URI for the dataset. Defaults to the THS_DATASET_AGGR_URI env var.

    Yields:
      AggregatedHazard: An object containing the aggregated hazard curve data.
    Raises:
      RuntimeWarning: describing any dataset partitions that could not be opened.
    """
    log.debug("> get_hazard_curves()")
    t0 = dt.datetime.now()

    count = 0

    if strategy == "d2":
        qfn = get_hazard_curves_by_vs30_nloc0
    elif strategy == "d1":
        qfn = get_hazard_curves_by_vs30
    else:
        qfn = get_hazard_curves_naive

    deferred_warning = None
    try:
        for obj in qfn(location_codes, vs30s, hazard_model, imts, aggs, dataset_uri):  # pragma: no branch
            count += 1
            yield obj
    except RuntimeWarning as err:
        if "Failed to open dataset" in str(err):
            deferred_warning = err
        else:
            raise err  # pragma: no cover

    t1 = dt.datetime.now()
    log.info(f"Executed dataset query for {count} curves in {(t1 - t0).total_seconds()} seconds.")

    if deferred_warning:  # pragma: no cover
        raise deferred_warning
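Under the 'd2' strategy the dataset is assumed to be partitioned on vs30 and nloc_0, so the user arguments can be decomposed into one targeted read per (vs30, nloc_0) partition. A sketch of deriving those partition keys; nloc0 here is a hypothetical 1-degree downsample standing in for downsample_code(code, 1.0):

```python
from itertools import product

def nloc0(code: str) -> str:
    # hypothetical 1-degree downsample, e.g. "-41.300~174.800" -> "-41.0~175.0"
    lat, lon = (round(float(p)) for p in code.split("~"))
    return f"{lat:.1f}~{lon:.1f}"

location_codes = ["-41.300~174.800", "-36.870~174.770"]
vs30s = [400, 750]

# one (vs30, nloc_0) pair per dataset partition to read
partitions = sorted({(v, nloc0(loc)) for v, loc in product(vs30s, location_codes)})
```

Each pair would then seed one narrowly-scoped pyarrow query, so only the matching parquet partitions are opened.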