Skip to content

Commit

Permalink
Merge pull request #11 from atsushi-green/develop
Browse files Browse the repository at this point in the history
Develop
  • Loading branch information
atsushi-green authored Sep 1, 2023
2 parents 29e4458 + cfbf8d9 commit e448bc4
Show file tree
Hide file tree
Showing 7 changed files with 262 additions and 223 deletions.
8 changes: 2 additions & 6 deletions index.html
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,13 @@
<meta charset="UTF-8">
<title>station2vec</title>
<meta name="viewport" content="width=device-width">
<link rel="stylesheet" href="my_style.css">
<link rel="icon" type="image/png" href="https://avatars.githubusercontent.com/u/129949522?v=4">
</head>

<body>
<p>
<label>駅名:<input type="text" id="targetStation" size="40" value="渋谷"></label>
<input type="button" value="「近い駅」を検索" id="searchButton">
<input type="button" value="似ている駅を検索" id="searchButton">

</p>
<p id="msg"></p>
Expand All @@ -23,7 +22,7 @@
<p id="msg2"></p>
<script src="./station2vec.js">
</script>
<p> 以下の駅一覧から選んで、テキストボックスに入力し、「「近い駅」を検索」ボタンを押してください。</p>
<p> 以下の駅一覧から選んで、テキストボックスに入力し、「似ている駅を検索」ボタンを押してください。</p>
<p> 東急 新横浜線/世田谷線は対象外としています。</p>

<table border="1">
Expand Down Expand Up @@ -380,9 +379,6 @@
<td>こどもの国</td>
</tr>




</body>

</html>
54 changes: 32 additions & 22 deletions scripts/MeshPopulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
import numpy as np
import pandas as pd

DAY_FLAG = 1 # 平日
HOLIDAY_FLG = 0 # 休日
WEEKDAY_FLAG = 1 # 平日
TIME_ZONE_NOON = 0 # 昼
TIME_ZONE_NIGHT = 1 # 深夜

Expand All @@ -36,11 +37,19 @@ def __init__(self, populatiopn_filepaths: List[Path], mesh_filepath: Path) -> No
df_list.append(df)
population_df = pd.concat(df_list)

# 平日のみ抽出
population_df = population_df[population_df["dayflag"] == DAY_FLAG]
# 平日休日それぞれ抽出
weekday_population_df = population_df[population_df["dayflag"] == WEEKDAY_FLAG]
holiday_population_df = population_df[population_df["dayflag"] == HOLIDAY_FLG]

# 昼と深夜をそれぞれ抽出
self.noon_population_df = population_df[population_df["timezone"] == TIME_ZONE_NOON]
self.night_population_df = population_df[population_df["timezone"] == TIME_ZONE_NIGHT]
# 平日昼
self.weekday_noon_population_df = weekday_population_df[weekday_population_df["timezone"] == TIME_ZONE_NOON]
# 平日深夜
self.weekday_night_population_df = weekday_population_df[weekday_population_df["timezone"] == TIME_ZONE_NIGHT]
# 休日昼
self.holiday_noon_population_df = holiday_population_df[holiday_population_df["timezone"] == TIME_ZONE_NOON]
# 休日深夜
self.holiday_night_population_df = holiday_population_df[holiday_population_df["timezone"] == TIME_ZONE_NIGHT]

# 全国のメッシュ情報を読み込む
self.mesh_df = pd.read_csv(mesh_filepath, dtype={"mesh1kmid": str})
Expand All @@ -62,27 +71,28 @@ def search_mesh_id(self, lon: float, lat: float) -> str:
# メッシュIDを返す
return self.mesh_df.iloc[min_idx]["mesh1kmid"]

def get_noon_population(self, mesh_id: str) -> float:
"""与えられたメッシュID地点の昼の人口を返す
def get_population(self, mesh_id: str, noon_night: str, weekday_holiday: str) -> float:
"""与えられたメッシュID地点の昼深夜・平日休日の人口を返す
Args:
mesh_id (str): 全国のメッシュID
noon_night (str): "noon" or "night"
weekday_holiday (str): "weekday" or "holiday"
Returns:
float: 昼人口
"""
# メッシュIDから人口を返す
return self.noon_population_df.loc[mesh_id]["population"]

def get_night_population(self, mesh_id: str) -> float:
"""与えられたメッシュID地点の深夜の人口を返す
Args:
mesh_id (str): 全国のメッシュID
Raises:
Exception: _description_
Returns:
float: 深夜人口
float: 条件に合致する人口
"""

# メッシュIDから人口を返す
return self.night_population_df.loc[mesh_id]["population"]
if noon_night == "noon":
if weekday_holiday == "weekday":
return self.weekday_noon_population_df.loc[mesh_id]["population"]
elif weekday_holiday == "holiday":
return self.holiday_noon_population_df.loc[mesh_id]["population"]
else:
if weekday_holiday == "weekday":
return self.weekday_night_population_df.loc[mesh_id]["population"]
elif weekday_holiday == "holiday":
return self.holiday_night_population_df.loc[mesh_id]["population"]
raise Exception("invalid argument")
11 changes: 6 additions & 5 deletions scripts/StationData.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,12 @@
from torch_geometric.transforms import NormalizeFeatures

# 与えるノード特徴量
USE_FEATURES = ["地価", "急行", "次数", "昼人口", "深夜人口", "昼夜人口差"]
# TODO: 土日と平日で分ける
USE_FEATURES = ["地価", "次数", "平日昼人口", "平日深夜人口", "平日昼夜人口差", "急行"]
USE_FEATURES = ["地価", "次数", "平日昼人口", "平日深夜人口", "休日昼人口", "平日昼夜人口差", "急行"]
# SQUARED_INDEXES = [0, 1, 2, 3, 4]
SQUARED_INDEXES = [0, 1, 2, 3, 4, 5]

CROSS_ENTROPY_INDEXES = [6]


class StationData(InMemoryDataset):
Expand Down Expand Up @@ -109,9 +113,6 @@ def calc_graph_distance(self, station_df: pd.DataFrame, edge_df: pd.DataFrame, s
edge_list.extend([i, i] for i in range(len(station2id))) # 自己ループを追加
edge_attr.extend([[0, 0]] * len(station2id)) # 自己ループを追加

print(distance_matrix)
print(hop_matrix)

return edge_list, edge_attr

def train_val_test_split(
Expand Down
Loading

0 comments on commit e448bc4

Please sign in to comment.