본문 바로가기

Algorithm

[LeetCode] Java, Python - 743. Network Delay Time

https://leetcode.com/problems/network-delay-time/

 

Network Delay Time - LeetCode

Level up your coding skills and quickly land a job. This is the best place to expand your knowledge and get prepared for your next interview.

leetcode.com

생각

모든 노드를 방문하고 가장 오래 걸리는 노드까지의 최단 비용을 구하기 위해 다익스트라 알고리즘을 사용하여 풀었습니다. 다익스트라 알고리즘을 구현하기위해 최소힙을 사용하여 문제를 풀었는데, heapq는 첫번째 요소를 기준으로 힙 정렬을 해줍니다. 저는 여기서 (노드, 비용) 이렇게 설정해줘서 계속 틀렸다고 나왔습니다ㅎㅎ

 

처음 시작지점을 힙에 넣고, 갈 수 있는 노드들을 힙에 넣어줍니다. 이때 내부적으로 최소힙정렬이 되기때문에 pop을 할때 최단거리를 가진 노드들이 나오게 됩니다. 이때 이미 방문한 노드의 비용이

 

dist 배열을 설정하고 '이미 저장된 dist 배열의 특정 노드 값  < 내가 pop한 노드의 비용 이라면 생략한다' 는 로직을 통해 구현했습니다. dist 배열을 언제 업데이트 해야할때는, '이미 방문한 노드의 비용 > 새로 계산한 비용' 일때 dist 배열을 업데이트 시켜줍니다. 그리고나서 인접노드를 heapq에 넣고 노드를 다시 업데이트해주면 됩니다.

 

Java Code

class Solution {
    public int networkDelayTime(int[][] times, int n, int k) {
        int[] dist = new int[n+1];
        Arrays.fill(dist, Integer.MAX_VALUE);

        Map<Integer, Map<Integer,Integer>> graph = new HashMap<>();
        for (int[] time : times) {
            graph.putIfAbsent(time[0], new HashMap<>());
            graph.get(time[0]).put(time[1], time[2]);
        }

        Queue<int[]> pq = new PriorityQueue<>((that, other) -> (that[0]-other[0]));
        pq.add(new int[] {0,k});
        dist[k] = 0;

        int ans = 0;

        while(!pq.isEmpty()) {
            int[] now = pq.poll();
            int now_cost = now[0];
            int now_node = now[1];
            if (dist[now_node] < now_cost) {
                continue;
            }

            ans = Math.max(ans, now_cost);

            if(graph.containsKey(now_node)) {
                for(int next_node : graph.get(now_node).keySet()) {
                    int next_cost = now_cost + graph.get(now_node).get(next_node);
                    if(dist[next_node] > next_cost) {
                        dist[next_node] = next_cost;
                        pq.add(new int[] {next_cost, next_node});
                    }

                }
            }


        }
        int cnt = 0;
        for (int i = 1; i <= n; i++) {
            if(dist[i] != Integer.MAX_VALUE) {
                cnt++;
            }
        }
        return cnt == n ? ans : -1;
    }
}

Python Code

from collections import deque
from collections import defaultdict
import heapq

class Solution:
    def networkDelayTime(self, times: List[List[int]], n: int, k: int) -> int:
        graph = defaultdict(list)
        dist = [sys.maxsize]*(n+1)
        visit = set()
        
        ans = 0
        
        for start,end,cost in times:
            graph[start].append((end,cost))
        
        q = [(0,k)] #비용, 시작노드 비용 기준으로 heap 정렬
        dist[k] = 0
        
        while q:
            now_cost, now = heapq.heappop(q)
            visit.add(now)
            if dist[now] < now_cost:
                continue
            
            ans = max(ans, now_cost)
            
            for next_node, w in graph[now]:
                next_weight = now_cost+w
                if dist[next_node] > next_weight :
                    dist[next_node] = next_weight
                    heapq.heappush(q, (next_weight, next_node))
                    
        return ans if len(visit) == n else -1