"""Sum, over every position, the distance to the last occurrence of its value.

Input (stdin):
    line 1: an integer n (parsed for validation; the list's actual length is
            what drives the computation, as in the original script)
    line 2: whitespace-separated integers
Output (stdout):
    the total of ``last_index(value) - i`` over every element at index ``i``.
"""


def sum_last_occurrence_distances(values) -> int:
    """Return the sum of ``last_index(values[i]) - i`` over all indices ``i``.

    For each element, the last index holding an equal value is always ``>= i``
    (the element itself matches), so each term is non-negative. Recording the
    last index of every value in one pass makes this O(n) rather than the
    O(n^2) backward scan per element.

    >>> sum_last_occurrence_distances([1, 2, 1, 3, 2])
    5
    >>> sum_last_occurrence_distances([])
    0
    """
    last_index = {}
    for idx, value in enumerate(values):
        # Later assignments overwrite earlier ones, leaving the last index.
        last_index[value] = idx
    return sum(last_index[value] - idx for idx, value in enumerate(values))


def main():
    """Read the count and the list from stdin and print the total."""
    int(input())  # declared count; parsed (so bad input still fails) but unused
    values = list(map(int, input().split()))
    print(sum_last_occurrence_distances(values))


if __name__ == "__main__":
    main()