기타/코딩테스트
[Python] 백준 1238번 - 파티
lazy man
2023. 3. 28. 17:59
문제정보
https://www.acmicpc.net/problem/1238
1238번: 파티
첫째 줄에 N(1 ≤ N ≤ 1,000), M(1 ≤ M ≤ 10,000), X가 공백으로 구분되어 입력된다. 두 번째 줄부터 M+1번째 줄까지 i번째 도로의 시작점, 끝점, 그리고 이 도로를 지나는데 필요한 소요시간 Ti가 들어
www.acmicpc.net
해결전략
노드, 간선, 가중치가 주어지고 최단 거리를 구하는 문제입니다. 다익스트라를 이용해서 해결이 가능하지만 일반적인 다이스트라로는 시간 초과가 발생합니다. heapq를 이용하여 가중치가 가장 작은 노드를 얻으면서 복잡도를 낮출 수 있습니다.
코드
(정답 코드)
import sys
import heapq
input = sys.stdin.readline
INF = int(1e9)
n, m, x = map(int, input().split())
graph = [[] for _ in range(n+1)]
# 간선 입력. a에서 b까지가는데 c가 걸린다
for _ in range(m):
a, b, c = map(int, input().split())
graph[a].append((b, c))
def dijkstra(start, end):
q = []
heapq.heappush(q, (0, start))
distance = [INF] * (n+1)
distance[start] = 0
while q:
dist, now = heapq.heappop(q)
# 이미 처리된 노드라면
if distance[now] < dist:
continue
for j in graph[now]:
cost = dist + j[1]
if cost < distance[j[0]]:
distance[j[0]] = cost
heapq.heappush(q, (cost, j[0]))
return distance[end]
result = 0
for i in range(1, n+1):
total = dijkstra(i, x) + dijkstra(x, i)
if result < total:
result = total
print(result)
(시간 초과)
import sys
input = sys.stdin.readline
INF = int(1e9)
n, m, x = map(int, input().split())
graph = [[] for _ in range(n+1)]
# 간선 입력. a에서 b까지가는데 c가 걸린다
for _ in range(m):
a, b, c = map(int, input().split())
graph[a].append((b, c))
# 방문하지 않은 노드중 거리가 가장 가까운 노드 취득
def get_smallest_node(distance, visited):
min_value = INF
index = 0
for i in range(1, n+1):
if distance[i] < min_value and not visited[i]:
min_value = distance[i]
index = i
return index
# min(n to x + x to n)
def dijkstra(start, end):
visited = [False] * (n+1)
distance = [INF] * (n+1)
# 시작점 방문 처리
visited[start] = True
# 자기 자신까지의 시간 0
distance[start] = 0
for j in graph[start]:
distance[j[0]] = j[1]
for i in range(n-1):
now = get_smallest_node(distance, visited)
visited[now] = True
for j in graph[now]:
cost = distance[now] + j[1]
if cost < distance[j[0]]:
distance[j[0]] = cost
return distance[end]
result = 0
for i in range(1, n+1):
total = dijkstra(i, x) + dijkstra(x, i)
if result < total:
result = total
print(result)