// 처음 내 코드
import java.io.*;
import java.util.*;
public class Main {
public static void main(String[] args) throws IOException {
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
StringTokenizer st = new StringTokenizer(br.readLine());
int N = Integer.parseInt(st.nextToken());
int M = Integer.parseInt(st.nextToken());
st = new StringTokenizer(br.readLine());
int[] arr = new int[N];
arr[0] = Integer.parseInt(st.nextToken());
for (int i = 1; i < N; i++) arr[i] = Integer.parseInt(st.nextToken()) + arr[i-1];
// System.out.println("Arrays.toString(arr) = " + Arrays.toString(arr));
int cnt = 0;
for (int i = 1; i < N; i++) {
// System.out.println("----");
for (int j = 0; j < N - i; j++) {
int t = arr[i+j] - arr[i];
if (t % M == 0) {
// System.out.println(i + " " + (i+j) + " = " + t);
cnt++;
}
}
}
System.out.println(cnt);
}
}
그렇게 됐을 때 O(N^2) 시간 복잡도로 N이 1 <= N <= 10^6 이므로 무조건 시간 초과날 수 밖에 없는 코드를 작성했고 실제로 시간 초과가 났다.
알고 보니 아주 수학적인 아이디어? 접근법?이 필요했다.
핵심 !
우선 입력에서 arr = [1, 2, 3, 1, 2] 일때 누적합 배열은 아래와 같다.
누적합 배열에서 prefixSum[i], prefixSum[j]이 있을 때 prefixSum[i] % M == prefixSum[j] % M 이면 prefixSum[j] - prefixSum[i]는 M으로 나누어떨어짐 이 포인트가 핵심적인 포인트였다.
따라서, modCount라는 크기가 M인 배열을 생성했고, prefixSum 배열에서 각 요소들을 M으로 나누어 그 나머지 값이 modCount에 몇 번 등장했는지를 세어주었습니다.
prefixSum % M 인 배열은 [1, 0, 0, 1, 0] 이므로 modCount 배열의 결과는 아래와 같다.(나머지가 0인거 3개, 1인거 2개, 2인거 0개)
나머지가 0인 경우는, 그 자체로도 구간합이 M으로 나눠떨어지므로 해당 개수만큼 정답에 더해준다.
또한, 위의 접근법에 따라 같은 나머지를 가진 누적합(prefixSum)이 2개 이상 존재한다면, 이들 중 2개를 선택했을 때 그 사이의 구간합도 M으로 나누어떨어지므로, 나머지가 같은 누적합 쌍의 수를 조합(조합 공식: nC2 = n × (n - 1) / 2)을 통해 계산하여 정답에 더해주면 된다.
그 결과 최종 코드이다.
import java.io.*;
import java.util.*;
public class Main {
public static void main(String[] args) throws IOException {
// 입력
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
StringTokenizer st = new StringTokenizer(br.readLine());
int N = Integer.parseInt(st.nextToken()); // 수의 개수
int M = Integer.parseInt(st.nextToken()); // 나누는 수
long[] prefixSum = new long[N + 1]; // 누적합 배열
long[] modCount = new long[M]; // 나머지 개수 배열
st = new StringTokenizer(br.readLine());
for (int i = 1; i <= N; i++) {
int num = Integer.parseInt(st.nextToken());
prefixSum[i] = prefixSum[i - 1] + num;
int mod = (int)(prefixSum[i] % M);
if (mod < 0) mod += M; // 음수 방지
modCount[mod]++;
}
// 누적합 자체가 나누어떨어지는 경우 (mod == 0)
long count = modCount[0];
// 같은 나머지에서 2개 뽑는 조합
// 같은 나머지끼리 2개 고르면 → 그 사이의 구간합은 M으로 나누어떨어진다.
for (int i = 0; i < M; i++) {
count += modCount[i] * (modCount[i] - 1) / 2;
}
System.out.println(count);
}
}
https://www.acmicpc.net/problem/2143 에서 부분합 크기 갯수만큼 카운팅하는 과정에서 N < 1000이라 직접 부부합을 브루트 포스로 확인하는 로직이 있었는데 그때 종종 활용하면 좋을 것같아 추가로 부분합을 구하는 메소드를 정리해봤다.
private static Map<Integer, Long> getSubarray(int[] arr) {
Map<Integer, Long> sumCounts = new HashMap<>();
int n = arr.length;
for (int start = 0; start < n; start++) {
int sum = 0;
for (int end = start; end < n; end++) {
sum += arr[end];
sumCounts.put(sum, sumCounts.getOrDefault(sum, 0L) + 1L);
}
}
return sumCounts;
}