문제
https://www.acmicpc.net/problem/1167
트리의 지름이란, 트리에서 임의의 두 점 사이의 거리 중 가장 긴 것을 말한다. 트리의 지름을 구하는 프로그램을 작성하시오.
입력
트리가 입력으로 주어진다. 먼저 첫 번째 줄에서는 트리의 정점의 개수 V가 주어지고 (2 ≤ V ≤ 100,000)둘째 줄부터 V개의 줄에 걸쳐 간선의 정보가 다음과 같이 주어진다. 정점 번호는 1부터 V까지 매겨져 있다.
먼저 정점 번호가 주어지고, 이어서 연결된 간선의 정보를 의미하는 정수가 두 개씩 주어지는데, 하나는 정점번호, 다른 하나는 그 정점까지의 거리이다. 예를 들어 네 번째 줄의 경우 정점 3은 정점 1과 거리가 2인 간선으로 연결되어 있고, 정점 4와는 거리가 3인 간선으로 연결되어 있는 것을 보여준다. 각 줄의 마지막에는 -1이 입력으로 주어진다. 주어지는 거리는 모두 10,000 이하의 자연수이다.
출력
첫째 줄에 트리의 지름을 출력한다.
예제
나의 풀이 (실패)
import sys
input = sys.stdin.readline
V = int(input())
tree = [[0] * (V+1) for _ in range(V+1)]
visited = [False] * (V+1)
for _ in range(V):
info = list(map(int, input().split()))
node = info[0]
i = 1
while True:
if info[i] == -1:
break
a, b = info[i], info[i+1]
tree[node][a] = b
i += 2
- 트리의 구조를 인접 행렬에 저장하는 로직까지만 생각하고 그 다음으로 넘어가지 못했다.
- 예를 들어 `tree[a][b] = c`라면, `a` 정점과 `b` 정점은 거리가 `c`인 간선으로 연결되었음을 의미한다.
다른 사람의 풀이1 (BFS 큐)
import sys
from collections import deque
input = sys.stdin.readline
V = int(input())
tree = [[] for _ in range(V+1)]
for _ in range(V):
info = list(map(int, input().split()))
node = info[0]
for i in range(1, len(info)-1, 2):
tree[node].append((info[i], info[i+1]))
def BFS(start):
visited = [-1] * (V+1)
visited[start] = 0
queue = deque([start])
while queue:
now = queue.popleft()
for next_node, next_distance in tree[now]:
if visited[next_node] == -1:
queue.append(next_node)
visited[next_node] = visited[now] + next_distance
farthest_distance = max(visited)
farthest_node = visited.index(farthest_distance)
return farthest_node, visited
farthest_node, _ = BFS(1)
_, visited = BFS(farthest_node)
print(max(visited))
- 인접 리스트 형태로 트리의 구조를 저장한다.
- `BFS()`를 통해 가장 멀리 있는 노드와 거리를 담은 리스트를 구한다.
- `visited` 리스트는 시작 노드부터 각 노드까지의 거리를 저장한다.
- 1차적으로 임의의 노드에서 시작해서 정점까지의 거리가 가장 긴 노드를 찾는다.
- 2차적으로 정점까지의 거리가 가장 긴 노드에서 시작하는 가장 긴 거리을 구한다. -> 이것이 트리의 지름
다른 사람의 풀이2 (DFS 재귀함수)
import sys
input = sys.stdin.readline
sys.setrecursionlimit(10 ** 9)
V = int(input())
tree = [[] for _ in range(V+1)]
for _ in range(V):
info = list(map(int, input().split()))
node = info[0]
for i in range(1, len(info)-1, 2):
tree[node].append((info[i], info[i+1]))
def DFS(node, distance):
for next_node, next_distance in tree[node]:
if visited[next_node] == -1:
visited[next_node] = visited[node] + next_distance
DFS(next_node, distance + next_distance)
visited = [-1] * (V+1)
visited[1] = 0
DFS(1, 0)
farthest_distance = max(visited)
farthest_node = visited.index(farthest_distance)
visited = [-1] * (V+1)
visited[farthest_node] = 0
DFS(farthest_node, 0)
print(max(visited))
- 재귀함수를 이용해서도 풀 수 있다.
- 먼저 거리가 가장 긴 노드를 구하고, 그 노드에서 출발해서 트리의 지름을 구하는 로직까지 동일하다.
회고
- 비록 내가 직접 풀진 못했지만 다른 사람의 풀이를 보고 로직을 이해하고, 이해한 로직을 바탕으로 다시 풀어보는 것만으로도 많이 배울 수 있었다.