173 lines
No EOL
4.4 KiB
Python
173 lines
No EOL
4.4 KiB
Python
|
|
from __future__ import annotations
|
|
from dataclasses import dataclass
|
|
from functools import lru_cache
|
|
import json
|
|
from math import ceil
|
|
from heapq import heapify, heappush, heappop
|
|
|
|
def _load_json(path):
    """Parse the JSON document stored at *path* and close the file afterwards.

    The file is read in binary mode so ``json`` detects the encoding itself
    (UTF-8/16/32, including a UTF-8 BOM) instead of relying on the platform's
    default text encoding.
    """
    with open(path, 'rb') as fh:
        return json.load(fh)


# Current network description; main_rails swaps in LEGACY_DATA when
# args.legacy is set.
ALL_DATA = _load_json('data/network.json')

# Older network description, loaded eagerly so main_rails can switch to it.
LEGACY_DATA = _load_json('data/network.1.json')

# "Unreached" distance sentinel used by dijkstra (2**31).
INFINITY = 2147483648
|
|
|
|
@dataclass
class Station:
    """A stop in the network: a display name plus its station code."""

    # Value stored in ALL_DATA['stations'] for this code. NOTE(review): the
    # search path in main_rails passes these entries through take_first, so
    # they may be lists of names rather than plain str — confirm.
    name: str
    # Key used to look the station up in ALL_DATA['stations'] and in line stops.
    code: str
|
|
|
|
@dataclass
class Line:
    """A transit line: its name, its code, and its ordered route segments."""

    name: str
    code: str
    route: list[RouteStep]

    def __contains__(self, st: Station | str):
        """Tell whether *st* is an endpoint of any segment on this line.

        *st* may be a Station object or a bare station code; a Station is
        compared by its ``code``.
        """
        wanted = st.code if isinstance(st, Station) else st
        return any(
            step.origin.code == wanted or step.target.code == wanted
            for step in self.route
        )
|
|
|
|
@dataclass
class RouteStep:
    """One segment of a line, joining two consecutive stops."""

    origin: Station  # stop the segment starts from
    target: Station  # stop the segment arrives at
    time: int        # travel time for this segment (unit not fixed here)
    line: str        # code of the Line this segment was built for
|
|
|
|
def find_station(code: str) -> Station | None:
    """Look *code* up in the station table of ALL_DATA.

    Returns a Station whose ``name`` is the table entry for *code*, or None
    when ALL_DATA has no ``'stations'`` table, the code is absent, or its
    entry is falsy.
    """
    table = ALL_DATA.get('stations', {})
    entry = table.get(code)
    if not entry:
        return None
    return Station(name=entry, code=code)
|
|
|
|
def take_first(s):
    """Return the first element of an iterable, or *s* itself if it is atomic.

    ``str`` and ``bytes`` are treated as atomic and returned unchanged (not
    split into characters/bytes). Objects without ``__iter__`` are also
    returned unchanged.

    Raises:
        IndexError: if *s* is an empty iterable.
    """
    if isinstance(s, (str, bytes)) or not hasattr(s, '__iter__'):
        return s
    # Pull a single element instead of materialising the whole iterable:
    # constant work, and safe on very large or unbounded iterators.
    for first in s:
        return first
    raise IndexError('take_first() of an empty iterable')
|
|
|
|
def _segment_time(step_data) -> int:
    """Travel time into a stop: ceil('time'), else ceil('dist' / 64), else 100."""
    try:
        return ceil(step_data['time'])
    except (KeyError, TypeError, ValueError, OverflowError):
        pass  # missing/non-numeric/NaN/inf 'time': try the distance next
    try:
        return ceil(step_data['dist'] / 64)
    except (KeyError, TypeError, ValueError, OverflowError):
        return 100  # TODO better fallback


def build_route_list(line_stops: list, line_code):
    """Pair up consecutive stops of a line into RouteStep segments.

    Args:
        line_stops: ordered stop records; each is indexed by ``'code'`` and,
            for every stop after the first, by ``'time'``/``'dist'`` to derive
            the travel time into that stop (see ``_segment_time``).
        line_code: code stored in each segment's ``line`` field.

    Returns:
        A list of RouteStep, one per consecutive pair of stops (empty for
        zero or one stop). Codes unknown to ``find_station`` are represented
        by a placeholder ``Station(code, code)``.
    """
    route_list = []
    last_step = None
    for step_data in line_stops:
        cur_step = step_data['code']
        # `is not None` rather than truthiness: a falsy code must not
        # silently drop the segment that leads into the next stop.
        if last_step is not None:
            route_list.append(RouteStep(
                origin=find_station(last_step) or Station(last_step, last_step),
                # The placeholder must describe the *current* stop, otherwise an
                # unknown target collapses into a self-loop on the origin.
                target=find_station(cur_step) or Station(cur_step, cur_step),
                time=_segment_time(step_data),
                line=line_code,
            ))
        last_step = cur_step
    return route_list
|
|
|
|
def build_all_lines():
    """Build a Line for every entry in ALL_DATA['lines']['overworld'].

    Returns a dict keyed by line code. ALL_DATA is read at call time, so
    calling this again after ALL_DATA is reassigned picks up the new data.
    """
    overworld = ALL_DATA['lines']['overworld']
    return {
        entry['code']: Line(
            code=entry['code'],
            name=entry['name'],
            route=build_route_list(entry['stops'], entry['code']),
        )
        for entry in overworld
    }
|
|
|
|
# Every overworld line keyed by line code; main_rails rebuilds it in legacy mode.
ALL_LINES: dict[str, Line] = build_all_lines()
|
|
|
|
## TODO: route-search algorithms
|
|
def find_route(start: str, stop: str):
    """Return the fastest route from station code *start* to *stop*.

    The result is a list of ``(Station, cumulative_time)`` pairs ordered from
    *start* to *stop*. ``cumulative_time`` is the shortest travel time from
    *start* as computed by ``dijkstra``. The first pair is always
    ``(start_station, 0)``.

    Edge cases:
      * ``start == stop`` -> a single ``(start_station, 0)`` pair.
      * *stop* unreachable from *start*, or unknown to ``dijkstra`` ->
        an empty list.

    Codes missing from the station table (``find_station`` returns None for
    them) get a placeholder ``Station(code, code)``, so every element exposes
    ``.code`` and ``.name``.
    """
    def station(code):
        return find_station(code) or Station(code, code)

    if stop == start:
        return [(station(start), 0)]

    dist, prev = dijkstra(start)

    if prev.get(stop) is None:
        # dijkstra never relaxed an edge into `stop`: there is no path to it.
        return []

    # Walk predecessors back from `stop`, then reverse once; appending and
    # reversing is linear, whereas repeated insert(0, ...) is quadratic.
    steps = []
    cur = stop
    while cur != start:
        steps.append((station(cur), dist[cur]))
        cur = prev[cur]
    steps.append((station(start), 0))
    steps.reverse()
    return steps
|
|
|
|
@lru_cache()
def find_neighbors(start):
    """List ``(neighbor_code, time)`` for every segment with *start* as an endpoint.

    Segments are traversable in both directions, so a segment contributes its
    far end whether *start* is its origin or its target. Lines are scanned in
    ALL_LINES order and segments in route order.

    Memoised on *start* alone: the result reflects ALL_LINES as it was on the
    first call for that code. Call ``find_neighbors.cache_clear()`` after
    ALL_LINES changes. The returned list is shared between calls; do not
    mutate it.
    """
    def touching(route):
        # Yield the far end of each segment that has `start` at either end.
        for step in route:
            if step.origin.code == start:
                yield step.target.code, step.time
            if step.target.code == start:
                yield step.origin.code, step.time

    return [
        pair
        for line in ALL_LINES.values()
        if start in line
        for pair in touching(line.route)
    ]
|
|
|
|
def dijkstra(start: str):
    """Shortest travel times from *start* to every reachable node.

    Classic Dijkstra with a binary heap and lazy deletion: stale heap entries
    are skipped once their node has been settled.

    Returns:
        ``(dist, prev)``. ``dist`` maps node -> best known time, INFINITY if
        unreached. ``prev`` maps node -> predecessor on that best path, None
        for *start* and for unreached nodes. Both start out keyed by
        ALL_DATA['stations']. Nodes reached only via route data are added
        as they are relaxed.
    """
    stations = ALL_DATA['stations']
    dist = dict.fromkeys(stations, INFINITY)
    dist[start] = 0
    prev = dict.fromkeys(stations)

    frontier = [(0, start)]
    settled = set()

    while frontier:
        so_far, node = heappop(frontier)
        if node in settled:
            continue  # outdated entry; node already has its final distance
        settled.add(node)

        for neighbor, step_time in find_neighbors(node):
            candidate = so_far + step_time
            if candidate >= dist.setdefault(neighbor, INFINITY):
                continue
            dist[neighbor] = candidate
            prev[neighbor] = node
            heappush(frontier, (candidate, neighbor))

    return dist, prev
|
|
|
|
|
|
|
|
def main_rails(args):
    """Command-line entry point: station search, or print a timed route.

    Reads these attributes from *args* (presumably an argparse Namespace;
    confirm against the caller): ``legacy``, ``search``, ``start``, ``end``,
    ``time``.

    * ``legacy``: switch ALL_DATA/ALL_LINES to the legacy network first.
    * ``search``: print every station whose (first) name contains the
      case-insensitive query, then return.
    * otherwise: print the fastest route from ``start`` to ``end``. Each stop
      is printed with ``time`` plus its cumulative travel time.
    """
    global ALL_DATA, ALL_LINES

    if args.legacy:
        ALL_DATA = LEGACY_DATA
        ALL_LINES = build_all_lines()
        # find_neighbors is memoised on the station code only, but its result
        # depends on ALL_LINES; discard anything computed from the old data.
        find_neighbors.cache_clear()

    if args.search:
        query = args.search.lower()
        for st_code, st_name in ALL_DATA['stations'].items():
            if query in take_first(st_name).lower():
                print('*', st_code, st_name)
        return

    st_start = find_station(args.start)
    st_end = find_station(args.end)
    st_time = args.time

    if not st_start or not st_end:
        print('error: missing stations')
        return

    route = find_route(st_start.code, st_end.code)
    if not route:
        print('error: no route found')
        return

    for st_step, time in route:
        # NOTE(review): HourMin is neither defined nor imported in the visible
        # source; confirm it is provided elsewhere, or this raises NameError.
        print(HourMin(st_time + time), st_step.code, st_step.name)