AtCoder Regular Contest 035 🧪 C - アットコーダー王国の交通事情 (Python)

久々投稿です!
おはこんばんにちは〜
今日は毎晩やっているバチャで面白い問題があったので紹介したいと思います!

問題はこちら!
https://atcoder.jp/contests/arc035/tasks/arc035_c

試験管🧪で難易度は青で1655diffとなってます!

解説記事の以上は初めてなので緊張してますが、頑張っていきます!(' _ ')

ダイクストラ法やワーシャルフロイド法を知っている方であれば、
灰色の方も理解できるように頑張って作ったので見ていただけると幸いです。

1. 問題の概要

要は全ての都市間の最短経路長の総和を求めろと言う問題。

総和を S, 2頂点間の最短距離を D(i, j) とすると、

 \displaystyle
S = \sum_{i=1}^{N-1} \sum_{j=i+1}^{N} D(i, j) \\

となるので、 S を出力する問題である。

ただ求めるだけならば、ダイクストラを貼って O(N^2) で終わるが、
この問題は K 個のクエリが与えられ、毎度辺が追加されるという特徴がある。

これをどう処理していくかが重要である。

2. 入力例と制約

 \displaystyle
N \quad M \\
A_1 \quad B_1 \quad C_1 \\
A_2 \quad B_2 \quad C_2 \\
: \\
A_M \quad B_M \quad C_M \\
K \\
X_1 \quad Y_1 \quad Z_1 \\
X_2 \quad Y_2 \quad Z_2 \\
: \\
X_K \quad Y_K \quad Z_K \\

  •  \displaystyle 1 \leqq N \leqq 400
  •  \displaystyle 1 \leqq M \leqq 1000
  •  \displaystyle 1 \leqq A_i, B_i \leqq N
  •  \displaystyle 1 \leqq C_i \leqq 1000
  •  \displaystyle 1 \leqq K \leqq 400
  •  \displaystyle 1 \leqq X_i, Y_i \leqq N
  •  \displaystyle 1 \leqq Z_i \leqq 1000

3. 解説

まず、計算量について考えてみよう。
今回の制約の特徴は何と言っても N が小さいことにある。

K 回のクエリごとに辺を追加し、最短距離を更新した後に、毎回全通りの D(i, j) を求めても O(KN^2) となるので間に合うことがわかる!

つまり、ダイクストラでクエリ処理前の全ての2頂点間の最短距離を O(N^2logN) 求めた後に、辺を追加していくことを考える。

では、頂点 (x, y) 間に z で行ける距離を追加した場合どのような更新が行われるかを考えよう。

以下、 dist[a][b] を a から b までの最短距離と定義する。

1つ目のパターンは、頂点 (a, b) の距離を a → x → y → b で辿っていく場合が考えられる。

f:id:ryusuke_920:20210714011914p:plain
a → x → y → b の順で辺を辿っていく場合

2つ目のパターンは、頂点 (a, b) の距離を a → y → x → b で辿っていく場合である。

f:id:ryusuke_920:20210714012002p:plain
a → y → x → b の順で辺を辿っていく場合

したがって、このように3パターン(更新せずに a → b に行くパターンを含めた)のうち、最短となるものを求めて更新していけば良いことがわかる。

a → b の組み合わせは 、

 \displaystyle
S = \sum_{a=1}^{N-1} \sum_{b=a+1}^{N} dist[a][b] \\
通りあるので、これは O(N^2) で処理することが可能となる。

したがって、

  1. 前処理としてダイクストラ法(ワーシャルフロイド法でも可能)で全ての2頂点間の最短距離を求める。
  2. K 回のクエリごとに最短距離を更新 O(N^2) し、クエリごとに毎回全ての2頂点間の全通り O(N^2) を試して総和を求めることができる。

したがって、計算量は O(N^2logN + KN^2) となる。

4. ACコード

ACしたコードはこちらになります。

import sys
input = sys.stdin.readline
from heapq import heappush, heappop

# ダイクストラ法で最短経路を求める
def dijkstra(s, graph):
    dist = [INF] * n
    check = [False] * n
    dist[s] = 0
    q = [(0, s)]
    while q:
        v = heappop(q)[1]
        if check[v]: continue
        check[v] = True
        for i, j in graph[v]:
            if check[i] != False: continue
            if dist[i] <= dist[v] + j: continue
            dist[i] = dist[v] + j
            heappush(q, (dist[i], i))
    return dist

INF = 10 ** 9
n, m = map(int,input().split())

# ノードの追加
g = [[] for _ in range(n)]
for _ in range(m):
    a, b, c = map(int,input().split())
    g[a - 1].append((b - 1, c))
    g[b - 1].append((a - 1, c))

# 全ての2頂点間の最短経路を求める
dist = []
for i in range(n):
    dist.append(dijkstra(i, g))

k = int(input())
for _ in range(k):
    x, y, z = map(int,input().split())

    # 最短距離の更新を行う
    for i in range(n):
        for j in range(n):
            if i == j: continue
            dist[i][j] = min(dist[i][j], dist[i][x - 1] + dist[y - 1][j] + z, dist[i][y - 1] + dist[j][x - 1] + z)

    # 総和 S を求める
    ans = 0
    for i in range(n - 1):
        for j in range(i + 1, n):
            ans += dist[i][j]

    print(ans)

5. 最後に

久々に解説記事を書きました。

途中ではPower Pointを使用した画像を入れてみましたが、どうでしょうか?

また気分が向いたら記事を書きたいと思います!!