Post

[백준] 2473번 - 세 용액 [Java][C++]

[백준] 2473번 - 세 용액 [Java][C++]

문제 링크


1. 문제 풀이

주어진 정수 값의 용액들에 대해 세 용액의 합이 $0$ 에 최대한 가까운 조합을 찾아서 출력해야 하는 문제다. $N$ 이 최대 $5,000$ 이라서 3중 반복문을 활용한 브루트 포스로는 해결할 수 없는데, 이분 탐색 또는 투 포인터 알고리즘을 활용하면 해결할 수 있다. 세 용액의 합이 정수 타입 오버플로우를 일으킬 수 있음에 주의해야 한다.

1. 이분 탐색

이분 탐색의 경우 2중 for문으로 두 개의 용액을 선택한 후 두 용액의 합의 부호를 반전한 값을 Lower Bound, Upper Bound 이분 탐색으로 찾았다. 부호를 반전하면 세 용액의 합이 가장 $0$ 에 가깝게 되는 나머지 용액을 선택하게 되는 것이다.

2. 투 포인터

투 포인터의 경우 하나의 용액을 미리 선택하고 선택한 용액의 오른쪽 구간에 대해 양 끝에 포인터를 위치시킨 후 세 용액의 합이 가장 $0$ 에 가까운 순간을 탐색 및 갱신했다.


2. 코드

1. 이분 탐색 [Java]

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import java.io.*;
import java.util.*;

public class Main {
    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        StringTokenizer st;

        int N = Integer.parseInt(br.readLine());

        int[] arr = new int[N];
        st = new StringTokenizer(br.readLine());
        for (int i = 0; i < N; i++) {
            arr[i] = Integer.parseInt(st.nextToken());
        }
        Arrays.sort(arr);

        long sum = 3_000_000_000L;
        int[] ans = {0, 0, 0};

        for (int i = 0; i < N - 2; i++) {
            for (int j = i + 1; j < N - 1; j++) {
                int idx1 = lowerBound(arr, j + 1, -(arr[i] + arr[j]));
                int idx2 = upperBound(arr, j + 1, -(arr[i] + arr[j])) - 1;

                if (idx1 != N && Math.abs((long) arr[i] + arr[j] + arr[idx1]) < sum) {
                    sum = Math.abs((long) arr[i] + arr[j] + arr[idx1]);
                    ans[0] = arr[i];
                    ans[1] = arr[j];
                    ans[2] = arr[idx1];
                }

                if (idx2 != j && idx2 != N && Math.abs((long) arr[i] + arr[j] + arr[idx2]) < sum) {
                    sum = Math.abs((long) arr[i] + arr[j] + arr[idx2]);
                    ans[0] = arr[i];
                    ans[1] = arr[j];
                    ans[2] = arr[idx2];
                }
            }
        }

        System.out.printf("%d %d %d", ans[0], ans[1], ans[2]);
    }

    static int lowerBound(int[] arr, int left, int key) {
        int right = arr.length;

        while (left < right) {
            int mid = (left + right) / 2;

            if (arr[mid] < key) {
                left = mid + 1;
            } else {
                right = mid;
            }
        }

        return right;
    }

    static int upperBound(int[] arr, int left, int key) {
        int right = arr.length;

        while (left < right) {
            int mid = (left + right) / 2;

            if (arr[mid] <= key) {
                left = mid + 1;
            } else {
                right = mid;
            }
        }

        return right;
    }
}

2. 투 포인터 [Java]

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import java.io.*;
import java.util.*;

public class Main {
    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        StringTokenizer st;

        int N = Integer.parseInt(br.readLine());

        int[] arr = new int[N];
        st = new StringTokenizer(br.readLine());
        for (int i = 0; i < N; i++) {
            arr[i] = Integer.parseInt(st.nextToken());
        }
        Arrays.sort(arr);

        long sum = 3_000_000_000L;
        int[] ans = {0, 0, 0};

        out:
        for (int i = 0; i < N - 2; i++) {
            long num = arr[i];

            int left = i + 1;
            int right = N - 1;

            while (left < right) {
                if (num + arr[left] + arr[right] > 0) {
                    if (Math.abs(num + arr[left] + arr[right]) < sum) {
                        sum = Math.abs(num + arr[left] + arr[right]);
                        ans[0] = arr[i];
                        ans[1] = arr[left];
                        ans[2] = arr[right];
                    }
                    right--;
                } else if (num + arr[left] + arr[right] < 0) {
                    if (Math.abs(num + arr[left] + arr[right]) < sum) {
                        sum = Math.abs(num + arr[left] + arr[right]);
                        ans[0] = arr[i];
                        ans[1] = arr[left];
                        ans[2] = arr[right];
                    }
                    left++;
                } else {
                    ans[0] = arr[i];
                    ans[1] = arr[left];
                    ans[2] = arr[right];
                    break out;
                }
            }
        }

        System.out.printf("%d %d %d", ans[0], ans[1], ans[2]);
    }
}

3. 이분 탐색 [C++]

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
#include <bits/stdc++.h>
using namespace std;

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int n;
    cin >> n;

    vector<long long> v(n);
    for (long long& x : v) cin >> x;
    sort(v.begin(), v.end());

    long long sum = 3000000000LL;
    vector<long long> ans(3);

    for (int i = 0; i < n - 2; i++) {
        for (int j = i + 1; j < n - 1; j++) {
            int idx1 = lower_bound(v.begin() + j + 1, v.end(), -(v[i] + v[j])) - v.begin();
            int idx2 = upper_bound(v.begin() + j + 1, v.end(), -(v[i] + v[j])) - v.begin() - 1;

            if (idx1 != n && llabs(v[i] + v[j] + v[idx1]) < sum) {
                sum = llabs(v[i] + v[j] + v[idx1]);
                ans = {v[i], v[j], v[idx1]};
            }

            if (idx2 != j && idx2 != n && llabs(v[i] + v[j] + v[idx2]) < sum) {
                sum = llabs(v[i] + v[j] + v[idx2]);
                ans = {v[i], v[j], v[idx2]};
            }
        }
    }

    cout << ans[0] << ' ' << ans[1] << ' ' << ans[2];
}

4. 투 포인터 [C++]

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
#include <bits/stdc++.h>
using namespace std;

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int n;
    cin >> n;

    vector<long long> v(n);
    for (long long& x : v) cin >> x;
    sort(v.begin(), v.end());

    long long sum = 3000000000LL;
    vector<long long> ans(3);

    for (int i = 0; i < n - 2; i++) {
        long long num = v[i];

        int l = i + 1;
        int r = n - 1;

        while (l < r) {
            if (num + v[l] + v[r] > 0) {
                if (llabs(num + v[l] + v[r]) < sum) {
                    sum = llabs(num + v[l] + v[r]);
                    ans = {v[i], v[l], v[r]};
                }
                r--;
            } else if (num + v[l] + v[r] < 0) {
                if (llabs(num + v[l] + v[r]) < sum) {
                    sum = llabs(num + v[l] + v[r]);
                    ans = {v[i], v[l], v[r]};
                }
                l++;
            } else {
                ans = {v[i], v[l], v[r]};

                cout << ans[0] << ' ' << ans[1] << ' ' << ans[2];
                return 0;
            }
        }
    }

    cout << ans[0] << ' ' << ans[1] << ' ' << ans[2];
}

3. 풀이 정보

1. 이분 탐색 [Java]

언어시간메모리코드 길이
Java 11836 ms16000 KB2164 B

2. 투 포인터 [Java]

언어시간메모리코드 길이
Java 11200 ms16272 KB1798 B

3. 이분 탐색 [C++]

언어시간메모리코드 길이
C++ 17400 ms2184 KB1019 B

4. 투 포인터 [C++]

언어시간메모리코드 길이
C++ 1736 ms2184 KB1157 B

This post is licensed under CC BY 4.0 by the author.