micorail/src/micorail/rails.py

173 lines
No EOL
4.4 KiB
Python

from __future__ import annotations
from dataclasses import dataclass
from functools import lru_cache
import json
from math import ceil
from heapq import heapify, heappush, heappop
def _load_json(path):
    """Read and parse the JSON file at *path*, always closing the handle.

    The file is decoded as UTF-8 (the JSON interchange encoding) rather than
    the platform's locale default, so non-ASCII station names load the same
    way on every OS.
    """
    with open(path, encoding='utf-8') as fh:
        return json.load(fh)


# Both datasets are read at import time. The paths are relative to the current
# working directory (NOTE(review): presumably the repository root — confirm).
ALL_DATA = _load_json('data/network.json')
# Alternative dataset that main_rails swaps in when args.legacy is set.
LEGACY_DATA = _load_json('data/network.1.json')
# Initial distance dijkstra assigns to stations not reached yet.
INFINITY = 2147483648
@dataclass
class Station:
    """A stop on the network, identified by its station code."""

    # Display name taken from the network data; stub stations reuse their code.
    name: str
    # Key of the station in ALL_DATA['stations'].
    code: str
@dataclass
class Line:
    """A named rail line, stored as the list of segments between its stops."""

    name: str
    code: str
    # Segments between consecutive stops, in the order built by build_route_list.
    route: list[RouteStep]

    def __contains__(self, st: Station | str):
        """Report whether *st* (a Station or a station code) ends any segment.

        A station matches if its code equals the origin or the target code of
        at least one RouteStep in ``self.route``.
        """
        wanted = st.code if isinstance(st, Station) else st
        return any(
            seg.origin.code == wanted or seg.target.code == wanted
            for seg in self.route
        )
@dataclass
class RouteStep:
    """One segment of a line, joining two consecutive stops."""

    # Stop the segment starts from.
    origin: Station
    # Stop the segment ends at.
    target: Station
    # Travel time for the segment, as computed by build_route_list.
    time: int
    # Code of the Line this segment belongs to.
    line: str
def find_station(code: str) -> Station | None:
    """Look up *code* in ``ALL_DATA['stations']``.

    Returns a Station carrying the stored name. Returns None when the code is
    absent, when the dataset has no 'stations' table, or when the stored name
    is falsy (e.g. an empty string).
    """
    table = ALL_DATA.get('stations', {})
    found_name = table.get(code)
    if not found_name:
        return None
    return Station(name=found_name, code=code)
def take_first(s):
    """Return the first item of an iterable, or *s* itself if it is not one.

    ``str`` and ``bytes`` are treated as scalars and returned unchanged. Any
    other object with ``__iter__`` yields its first item, and only that item
    is consumed. main_rails uses this on station-name entries, which
    presumably may be a single string or a sequence of names. Anything else
    is returned as-is.

    Raises:
        IndexError: if *s* is an empty (non-string) iterable.
    """
    if isinstance(s, (str, bytes)):
        return s
    if hasattr(s, '__iter__'):
        # Read just the first element instead of materialising the whole iterable.
        for first in s:
            return first
        # Keep the exception type that the former ``list(s)[0]`` raised.
        raise IndexError('take_first() arg is an empty iterable')
    return s
def _step_time(step_data, dist_per_time, fallback_time):
    """Return the travel time for the segment ending at *step_data*.

    Uses ceil(step_data['time']) if it is usable. Otherwise uses
    ceil(step_data['dist'] / dist_per_time). If neither works, returns
    *fallback_time*.
    """
    try:
        return ceil(step_data['time'])
    except (KeyError, TypeError, ValueError, OverflowError):
        pass  # missing or non-numeric 'time': try distance next
    try:
        return ceil(step_data['dist'] / dist_per_time)
    except (KeyError, TypeError, ValueError, OverflowError, ZeroDivisionError):
        return fallback_time  # TODO better fallback


def build_route_list(line_stops: list, line_code, *, dist_per_time: float = 64,
                     fallback_time: int = 100):
    """Build the RouteStep segments between consecutive stops of one line.

    Args:
        line_stops: ordered stop entries. Each needs a 'code' key and may carry
            'time' or 'dist' describing the segment that ends at that stop.
        line_code: code of the line, stored on every RouteStep.
        dist_per_time: divisor that converts 'dist' into a time when 'time'
            is unusable (default 64, as before; units as in the data).
        fallback_time: time used when neither 'time' nor 'dist' is usable.

    Returns:
        A list of RouteStep, one per consecutive pair of stops (empty when
        there are fewer than two stops). Codes unknown to find_station become
        stub Stations named after their own code.
    """
    route_list = []
    prev_code = None
    for step_data in line_stops:
        cur_code = step_data['code']
        if prev_code:
            route_list.append(RouteStep(
                origin=find_station(prev_code) or Station(prev_code, prev_code),
                # The stub must describe *this* stop, not the previous one.
                target=find_station(cur_code) or Station(cur_code, cur_code),
                time=_step_time(step_data, dist_per_time, fallback_time),
                line=line_code,
            ))
        prev_code = cur_code
    return route_list
def build_all_lines():
    """Build every overworld line in ALL_DATA, keyed by line code.

    Each entry of ``ALL_DATA['lines']['overworld']`` becomes a Line whose
    route comes from build_route_list. If two entries share a code, the later
    one wins.
    """
    overworld = ALL_DATA['lines']['overworld']
    return {
        entry['code']: Line(
            code=entry['code'],
            name=entry['name'],
            route=build_route_list(entry['stops'], entry['code']),
        )
        for entry in overworld
    }
# Lines indexed by code, built once at import time from ALL_DATA.
# main_rails rebuilds this after switching to the legacy dataset.
ALL_LINES = build_all_lines()
## TODO: route-search algorithms
def find_route(start: str, stop: str):
    """Return the fastest route from *start* to *stop* as (Station, time) pairs.

    The times are cumulative travel times from *start*, as computed by
    dijkstra. The first pair is (start, 0) and the last is (stop, total time).
    A code missing from ``ALL_DATA['stations']`` is represented by a stub
    Station whose name is its code. Returns an empty list when *stop* cannot
    be reached from *start*.
    """
    def as_station(code):
        # Stations reachable via lines but absent from the stations table get a stub.
        return find_station(code) or Station(code, code)

    if start == stop:
        # No search needed.
        return [(as_station(start), 0)]

    dist, prev = dijkstra(start)
    if dist.get(stop, INFINITY) >= INFINITY:
        return []  # unreachable

    # Walk predecessors back from stop, then reverse (avoids O(n^2) front inserts).
    steps = []
    cur = stop
    while cur != start:
        if cur is None:
            return []  # broken predecessor chain: treat as unreachable
        steps.append((as_station(cur), dist[cur]))
        cur = prev.get(cur)
    steps.append((as_station(start), 0))
    steps.reverse()
    return steps
@lru_cache()
def find_neighbors(start):
    """List (neighbor_code, time) pairs for every segment touching *start*.

    Scans all lines in ALL_LINES and treats every segment as usable in both
    directions. Results are memoised per station code by lru_cache, so they
    reflect ALL_LINES as it was when the code was first queried.
    """
    found = []
    for ln in ALL_LINES.values():
        if start not in ln:
            continue
        for seg in ln.route:
            if seg.origin.code == start:
                found.append((seg.target.code, seg.time))
            if seg.target.code == start:
                found.append((seg.origin.code, seg.time))
    return found
def dijkstra(start: str):
    """Run Dijkstra's shortest-path search from *start* over find_neighbors.

    Returns ``(dist, prev)``:
      * dist maps station code -> shortest total time from start
        (INFINITY if never reached).
      * prev maps station code -> predecessor on that shortest path
        (None for start and for unreached stations).
    Both dicts are seeded with every code in ``ALL_DATA['stations']``. Codes
    reached through lines but missing from that table are added as found.
    """
    stations = ALL_DATA['stations']
    best = dict.fromkeys(stations, INFINITY)
    best[start] = 0
    came_from = dict.fromkeys(stations)

    frontier = [(0, start)]  # a one-element list is already a valid heap
    settled = set()
    while frontier:
        here_dist, here = heappop(frontier)
        if here in settled:
            continue  # stale heap entry; a shorter distance was already settled
        settled.add(here)
        for nxt, cost in find_neighbors(here):
            candidate = here_dist + cost
            if candidate < best.setdefault(nxt, INFINITY):
                best[nxt] = candidate
                came_from[nxt] = here
                heappush(frontier, (candidate, nxt))
    return best, came_from
def main_rails(args):
    """Command-line entry point: search stations by name or print a timed route.

    Reads ``args.legacy``, ``args.search``, ``args.start``, ``args.end`` and
    ``args.time``.

    If ``args.legacy`` is set, the legacy dataset replaces ALL_DATA and the
    lines are rebuilt. If ``args.search`` is set, every station whose
    (first) name contains the query, case-insensitively, is printed and the
    function returns. Otherwise the route from start to end is printed, one
    stop per line, with ``args.time`` added to each cumulative time.
    """
    global ALL_DATA, ALL_LINES
    if args.legacy:
        ALL_DATA = LEGACY_DATA
        ALL_LINES = build_all_lines()
        # find_neighbors memoises results computed from the previous ALL_LINES.
        find_neighbors.cache_clear()
    if args.search:
        query = args.search.lower()
        for st_code, st_name in ALL_DATA['stations'].items():
            if query in take_first(st_name).lower():
                print('*', st_code, st_name)
        return
    st_start = find_station(args.start)
    st_end = find_station(args.end)
    st_time = args.time
    if not st_start or not st_end:
        print('error: missing stations')
        return
    route = find_route(st_start.code, st_end.code)
    if not route:
        print('error: no route found')
        return
    for st_step, time in route:
        # NOTE(review): HourMin is neither defined nor imported in this file as
        # shown, so this raises NameError unless it is provided elsewhere — confirm.
        print(HourMin(st_time + time), st_step.code, st_step.name)