diff --git a/simple/stats/events_importer.py b/simple/stats/events_importer.py index c4a49caa..636142fa 100644 --- a/simple/stats/events_importer.py +++ b/simple/stats/events_importer.py @@ -55,6 +55,8 @@ def __init__(self, input_fh: FileHandler, db: Db, self.entity_type = self.config.entity_type(self.input_file_name) self.ignore_columns = self.config.ignore_columns(self.input_file_name) self.provenance = self.nodes.provenance(self.input_file_name).id + # Reassign after reading CSV. + self.entity_column_name = constants.COLUMN_DCID self.event_type = self.config.event_type(self.input_file_name) assert self.event_type, f"Event type must be specified: {self.input_file_name}" @@ -70,8 +72,8 @@ def do_import(self) -> None: self._read_csv() self._drop_ignored_columns() self._sanitize_values() - self._resolve_entities() self._rename_columns() + self._resolve_entities() self._write_event_triples() self._write_observations() self.reporter.report_success() @@ -91,6 +93,8 @@ def _read_csv(self) -> None: skipinitialspace=True, thousands=",") logging.info("Read %s rows.", self.df.index.size) + self.entity_column_name = self.df.columns[0] + logging.info("Entity column name: %s", self.entity_column_name) def _drop_ignored_columns(self): if self.ignore_columns: @@ -203,8 +207,8 @@ def _write_event_triples(self) -> None: def _resolve_entities(self) -> None: df = self.df - # get first (0th) column - column = df.iloc[:, 0] + # get entity column + column = df[constants.COLUMN_DCID] pre_resolved_entities = {} @@ -221,17 +225,21 @@ def remove_pre_resolved(entity: str) -> bool: logging.info("Resolving %s entities of type %s.", len(entities), self.entity_type) - dcids = self._resolve(entity_column_name=df.columns[0], entities=entities) + dcids = self._resolve(entities=entities) logging.info("Resolved %s of %s entities.", len(dcids), len(entities)) # Replace resolved entities. - column.replace(dcids, inplace=True) + # NOTE: column.map performs much better than column.replace, hence using the former. + column = column.map(lambda x: dcids.get(x, x)) unresolved = set(entities).difference(set(dcids.keys())) unresolved_list = sorted(list(unresolved)) # Replace pre-resolved entities without the "dcid:" prefix. - column.replace(pre_resolved_entities, inplace=True) + column = column.map(lambda x: pre_resolved_entities.get(x, x)) + logging.info("Replaced %s pre-resolved entities.", + len(pre_resolved_entities)) + df[constants.COLUMN_DCID] = column if unresolved_list: logging.warning("# unresolved entities which will be dropped: %s", len(unresolved_list)) @@ -244,9 +252,8 @@ def remove_pre_resolved(entity: str) -> bool: unresolved=unresolved_list, ) - def _resolve(self, entity_column_name: str, - entities: list[str]) -> dict[str, str]: - lower_case_entity_name = entity_column_name.lower() + def _resolve(self, entities: list[str]) -> dict[str, str]: + lower_case_entity_name = self.entity_column_name.lower() # Check if the entities can be resolved locally. # If so, return them by prefixing the values as appropriate. diff --git a/simple/stats/observations_importer.py b/simple/stats/observations_importer.py index 14779b54..e403bf75 100644 --- a/simple/stats/observations_importer.py +++ b/simple/stats/observations_importer.py @@ -48,6 +48,8 @@ def __init__(self, input_fh: FileHandler, db: Db, self.config = nodes.config self.entity_type = self.config.entity_type(self.input_file_name) self.ignore_columns = self.config.ignore_columns(self.input_file_name) + # Reassign after reading CSV. + self.entity_column_name = constants.COLUMN_DCID self.df = pd.DataFrame() self.debug_resolve_df = None @@ -57,8 +59,8 @@ def do_import(self) -> None: self._read_csv() self._drop_ignored_columns() self._sanitize_values() - self._resolve_entities() self._rename_columns() + self._resolve_entities() self._add_provenance_column() self._add_entity_nodes() self._write_observations() @@ -79,6 +81,8 @@ def _read_csv(self) -> None: skipinitialspace=True, thousands=",") logging.info("Read %s rows.", self.df.index.size) + self.entity_column_name = self.df.columns[0] + logging.info("Entity column name: %s", self.entity_column_name) def _drop_ignored_columns(self): if self.ignore_columns: @@ -158,8 +162,8 @@ def _resolve_entity_type(self) -> str: def _resolve_entities(self) -> None: df = self.df - # get first (0th) column - column = df.iloc[:, 0] + # get entity column + column = df[constants.COLUMN_DCID] pre_resolved_entities = {} @@ -176,17 +180,19 @@ def remove_pre_resolved(entity: str) -> bool: logging.info("Resolving %s entities of type %s.", len(entities), self.entity_type) - dcids = self._resolve(entity_column_name=df.columns[0], entities=entities) + dcids = self._resolve(entities=entities) logging.info("Resolved %s of %s entities.", len(dcids), len(entities)) # Replace resolved entities. - column.replace(dcids, inplace=True) + # NOTE: column.map performs much better than column.replace, hence using the former. + column = column.map(lambda x: dcids.get(x, x)) unresolved = set(entities).difference(set(dcids.keys())) unresolved_list = sorted(list(unresolved)) # Replace pre-resolved entities without the "dcid:" prefix. - column.replace(pre_resolved_entities, inplace=True) + column = column.map(lambda x: pre_resolved_entities.get(x, x)) + df[constants.COLUMN_DCID] = column if unresolved_list: logging.warning("# unresolved entities which will be dropped: %s", len(unresolved_list)) @@ -199,9 +205,8 @@ def remove_pre_resolved(entity: str) -> bool: unresolved=unresolved_list, ) - def _resolve(self, entity_column_name: str, - entities: list[str]) -> dict[str, str]: - lower_case_entity_name = entity_column_name.lower() + def _resolve(self, entities: list[str]) -> dict[str, str]: + lower_case_entity_name = self.entity_column_name.lower() # Check if the entities can be resolved locally. # If so, return them by prefixing the values as appropriate.