1744 - 수 묶기
본문
길이가 $N$인 수열이 주어졌을 때, 그 수열의 합을 구하려고 한다. 하지만, 그냥 그 수열의 합을 모두 더해서 구하는 것이 아니라, 수열의 두 수를 묶으려고 한다. 어떤 수를 묶으려고 할 때, 위치에 상관없이 묶을 수 있다. 하지만, 같은 위치에 있는 수(자기 자신)를 묶는 것은 불가능하다. 그리고 어떤 수를 묶게 되면, 수열의 합을 구할 때 묶은 수는 서로 곱한 후에 더한다.
예를 들면, 어떤 수열이 $\lbrace 0, 1, 2, 4, 3, 5 \rbrace$일 때, 그냥 이 수열의 합을 구하면 $0+1+2+4+3+5 = 15$이다. 하지만, $2$와 $3$을 묶고, $4$와 $5$를 묶게 되면, $0+1+(2 \times 3)+(4 \times 5) = 27$이 되어 최대가 된다.
수열의 모든 수는 단 한번만 묶거나, 아니면 묶지 않아야한다.
수열이 주어졌을 때, 수열의 각 수를 적절히 묶었을 때, 그 합이 최대가 되게 하는 프로그램을 작성하시오.
입력
첫째 줄에 수열의 크기 $N$이 주어진다. $N$은 $50$보다 작은 자연수이다. 둘째 줄부터 $N$개의 줄에 수열의 각 수가 주어진다. 수열의 수는 $-1\,000$보다 크거나 같고, $1\,000$보다 작거나 같은 정수이다.
출력
수를 합이 최대가 나오게 묶었을 때 합을 출력한다. 정답은 항상 $2^{31}$보다 작다.
제한
시간 제한 | 메모리 제한 |
---|---|
2sec | 128MB |
풀이
모든 수는 묶이지 않거나 “단 한번” 묶일 수 있다는 것을 기억하자. 한번 묶이거나 묶이지 않았다고 판단된 수는 바로 결과에 더해주는 방향으로 처리하자.
두 수를 묶는 기준부터 정해보도록 하자. 어떤 두 수 $a$와 $b$가 정해졌다면, $a+b \leq a \times b$라면 합치고 아니라면 합치지 않아야할 것이다. 그럼 두 수를 정하는 방법만 찾아내면 정답을 찾을 수 있을 것 같다.
그럼 두 수를 정하는 방법을 찾아야겠다. 적용될 연산은 곱셈 연산 $\times$. 가장 큰 두 수를 곱하면 더욱 큰 수가 된다는 것이 일반적이다. 다만 이 문제에는 양수가 아닌 수들도 입력으로 주어진다는 점이 까다로운 점이다.
몇 가지 상황으로 나눠보도록하자. 어떤 두 수의 부호가 다를 경우에는, 곱하면 무조건 음수이므로 곱셈 연산을 하면 안된다. 그럼 곱셈 연산을 취할 가능성이 있을 때는 두 수의 부호가 같을 때가 될 것이다. 두 수가 모두 양수일 때는 앞서 얘기하듯 $a+b \leq a \times b$일 때 곱해주면 된다. 부호가 음수라면, 더하는 것보다는 곱해서 양수로 만드는 것이 결과값을 더 키울 수 있는 방법이 될 것이다.
그럼 결과값을 크게 하기 위해 “두 수 $a, b$를 어떻게 찾아낼까”가 관건이 될 것이다. 곱셈 연산의 결과값을 크게 하려면 당연히 절댓값이 큰 순서대로 곱해가면 된다. 수학적으로는 간단한 문제.
다만 이것을 프로그래밍으로 풀어나가려면 조금 애를 먹는다. 일단 우선순위 큐를 2개 만들어서 각각 양수와 음수를 저장한다. 이 큐는 각각 절댓값이 큰 순서대로 정렬하면 된다. 아니라면 절댓값의 순서대로 정렬해서 사용해도 된다. 나는 정렬하는 방법을 사용했고, 스택구조를 활용할 목적으로 역순으로 정렬 후 사용했다. 절댓값순으로 정렬했다면 위에서 설립한 규칙대로 모두 곱해서 더해주면 된다. 양수의 경우에는 더한 경우보다 곱한 경우가 더 결과값이 큰지 확인을 한번 거쳐야한다. 모든 수를 묶은 다음 결과값에 포함되지 않은 수가 남아있는지 꼭 확인하고 결과를 출력하자.
- 참고 알고리즘 : 우선순위 큐
코드
사용 언어 : C
최종 수정일 : 2024 / 9 / 9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
#include <stdio.h>
#define MAX_IDX 50
int pos[MAX_IDX], pos_len;
int neg[MAX_IDX], neg_len;
int n;
#define abs(x) (((x) > 0) ? (x) : (-(x)))
int cmp(int* a, int* b) { return abs(*a) > abs(*b); }
int main() {
scanf("%d", &n);
for (int i = 0; i < n; ++i) {
int a;
scanf("%d", &a);
if (a > 0) {
pos[pos_len++] = a;
} else {
neg[neg_len++] = a;
}
}
qsort(pos, pos_len, sizeof(int), cmp);
qsort(neg, neg_len, sizeof(int), cmp);
int retval = 0;
/* solve() */
while (pos_len > 1) {
int a = pos[--pos_len];
int b = pos[--pos_len];
if (a * b > a + b) {
retval += (a * b);
} else {
retval += (a + b);
}
}
if (pos_len > 0) {
retval += pos[0];
}
while (neg_len > 1) {
int a = neg[--neg_len];
int b = neg[--neg_len];
retval += (a * b);
}
if (neg_len > 0) {
retval += neg[0];
}
printf("%d", retval);
return 0;
}