一道Google top coder的850分例题及解答
原题:
假设有这样一种字符串,它们的长度不大于 26 ,而且若一个这样的字符串其长度为 m ,则这个字符串必定由 a, b, c ... z 中的前 m 个字母构成,同时我们保证每个字母出现且仅出现一次。比方说某个字符串长度为 5 ,那么它一定是由 a, b, c, d, e 这 5 个字母构成,不会多一个也不会少一个。嗯嗯,这样一来,一旦长度确定,这个字符串中有哪些字母也就确定了,唯一的区别就是这些字母的前后顺序而已。
现在我们用一个由大写字母 A 和 B 构成的序列来描述这类字符串里各个字母的前后顺序:
l 如果字母 b 在字母 a 的后面,那么序列的第一个字母就是 A (After),否则序列的第一个字母就是 B (Before);
l 如果字母 c 在字母 b 的后面,那么序列的第二个字母就是 A ,否则就是 B;
l 如果字母 d 在字母 c 的后面,那么 …… 不用多说了吧?直到这个字符串的结束。
这规则甚是简单,不过有个问题就是同一个 AB 序列,可能有多个字符串都与之相符,比方说序列"ABA",就有"acdb"、"cadb"等等好几种可能性。说的专业一点,这一个序列实际上对应了一个字符串集合。那么现在问题来了:给你一个这样的AB 序列,问你究竟有多少个不同的字符串能够与之相符?或者说这个序列对应的字符串集合有多大?注意,只要求个数,不要求枚举所有的字符串。
注:如果结果大于10亿就返回-1。
我的最终解答(没有考虑溢出的情况):
// CODE 1
// the best way
// O(N^2)
int countABbest(const string& AB)
{
assert(AB.find_first_not_of("AB") == string::npos);
vector<int> current, next; // should we reserve these vectors?
current.push_back(1);
for (string::const_iterator letter = AB.begin();
letter != AB.end(); ++letter) {
next.resize(current.size()+1); // or next.insert(next.end(), 2, 0);
next[0] = 0; // in fact, we could set the entire vector to zero
if (*letter == 'A') {
partial_sum(current.begin(), current.end(), next.begin()+1);
} else {
partial_sum(current.rbegin(), current.rend(), next.begin()+1);
reverse(next.begin(), next.end());
}
swap(current, next);
}
return accumulate(current.begin(), current.end(), 0);
}
int main()
{
constchar* AB = "ABBAAB";
printf("'%s' : %d\n", AB, countABbest(AB));
}
下面谈一谈我在解决这个问题时的思路。
第一步 初步分析
以下“字符串”特指题目中提到的由小写字母a、b、c等等组成的字符串,每个字母出现且仅出现一次。显然题目要求我们写一个函数f,f的输入是一个长度为v 的AB序列w,w代表了一个字符串集合s(集合中的元素都是长度为m(m=v+1)的字符串),f的返回值是这个集合的元素个数|s|,即|s|=f(w)。用高中学过的一点排列组合知识,可分析知:
1. 长度为m的字符串有m! 个(’!’ 表示阶乘)因为这相当于m个不同字母的全排列;
2. 长度为v的AB序列有2^v个(’^’ 表示指数)因为每个位置有2种可能,一共有v个位置;
3. 由于2^v <= m! (m=v+1),所以AB序列的数目不大于字符串的数目。
4. 每个字符串刚好有一个AB序列与之对应。比如对于字符串”abdec”,我们很容易得知b在a后,c在b后,d在c前,e在d后,因此它对应的AB序列为”AABA”。可见拿到一个字符串,立刻就能求出它对应的那一个AB序列。
5. 每个AB序列至少对应一个字符串(当然也对应多个,因为字符串数目远远大于AB序列数目)。比如任取一个AB序列”ABA”,很容易构造出与它对应的字符串:
i.
b在a后,得”ab”;
ii.
c在b前,得”acb”或”cab”;
iii.
d在c后,拿”acb”来说,可得”acdb”和”acbd”;拿”cab”来说,可得 ”cdab”、”cadb”和”cabd”;这样一共构造了5个与”ABA”对应的字符串,而且不会再有别的字符串了(why?)。
其实我们已经找到了蛮力解决问题的办法。
6. 根据4、5,得知如果穷举出长度为v的AB序列(共2^v个),并计算每个序列对应的字符串数目,那么把所有这些数目加起来,应该等于(v+1)!,这可以用作我们算法的一个检验。
7. 其实这可以看作集合的划分,把一个有 m! 个元素的集合U划分为2^v个不相交的子集s_0, s_1, s_{2^v–1},每个子集s_i是一个类别,每个字符串都属于一个类别,问题转变为求给定类别中有多少个元素。
第二步 蛮力解决
在想到前面的分析之前,我先用一种蛮力办法部分地解决了这个问题,思路是拿到一个长度为v的AB序列,穷举所有长度为v+1的字符串,遇到匹配的就记录下来。这样得到第一个程序,这个程序虽然效率极低,但可以用来检验后面程序的正确性,是个标竿。
// CODE 2
bool match(const string& AB, const string& str)
{
// many ways to improve this function, but we won’t bother it.
for (size_t i = 0; i < AB.length(); ++i) {
size_t first = str.find('a'+i);
size_t second = str.find('a'+i+1);
assert(first != string::npos && second != string::npos);
if (AB[i] == 'A' && first > second) {
returnfalse;
} elseif (AB[i] == 'B' && first < second) {
returnfalse;
}
}
returntrue;
}
// the stupid way
// O(N! * N^2)
int countAB(const string& AB)
{
assert(AB.find_first_not_of("AB") == string::npos);
string str;
int count = 0;
int m = (int)AB.length() + 1;
// construct the initial string
for (int i = 0; i < m; ++i) {
str.push_back('a'+i);
}
do {
if (match(AB, str)) {
printf("%s, ", str.c_str());
count++;
}
} while (next_permutation(str.begin(), str.end()));
return count;
}
上面这个程序是以AB序列为中心,想办法找到与它匹配的字符串。为了看它能否通过第6点分析的检验,我写了一个enumAB(int v)函数,用来穷举长度为v的所有AB序列,并做检验(检验基本靠眼)。
// CODE 3
void enumAB(int v)
{
assert(0 <= v && v < 26);
int nAB = 1 << v;
int total = 0;
for (int i = 0; i < nAB; ++i) {
string AB;
for (int bit = v-1; bit >= 0; --bit) {
if (i & (1 << bit)) {
AB.push_back('B');
} else {
AB.push_back('A');
}
}
int count = countAB(AB);
total += count;
printf("%s : %d\n", AB.c_str(), count);
}
printf("\nTotal strings: %d\n", total);
}
以下是enumAB(4)的运行结果(5!=120,初步检验通过):
AAAA : 1
AAAB : 4
AABA : 9
AABB : 6
ABAA : 9
ABAB : 16
ABBA : 11
ABBB : 4
BAAA : 4
BAAB : 11
BABA : 16
BABB : 9
BBAA : 6
BBAB : 9
BBBA : 4
BBBB : 1
Total strings: 120
如果想穷举所有AB序列和它们对应的字符串,还可以用一种效率稍高的蛮力算法,以字符串为中心,穷举所有长度为m的字符串,把它归入相应的AB序列名下。代码如下。
// CODE 4
string getAB(const string& str)
{
constchar* alphabet = "abcdefghijklmnopqrstuvwxyz";
assert(str.find_first_not_of(alphabet, 0, str.length()) == string::npos);
int pos[26] = {0};
char AB[26] = {0};
int m = (int)str.length();
for (int i = 0; i < m; ++i) {
pos[str[i]-'a'] = i;
}
for (int i = 0; i < m-1; ++i) {
AB[i] = pos[i] < pos[i+1] ? 'A' : 'B';
}
return AB; // we are not return the local char array, but a string object.
}
void enumStr(int m)
{
string str;
int nAB = 0;
for (int i = 0; i < m; ++i) {
str.push_back(char('a'+i));
}
map<string, vector<string> > AB2strs;
do {
string AB = getAB(str);
//printf("%s is of %s\n", str.c_str(), AB.c_str());
AB2strs[AB].push_back(str);
} while (next_permutation(str.begin(), str.end()));
for (map<string, vector<string> >::iterator it = AB2strs.begin();
it != AB2strs.end(); ++it) {
++nAB;
printf("%s (%d): ", it->first.c_str(), it->second.size());
for (vector<string>::iterator str = it->second.begin();
str != it->second.end(); ++str) {
printf("%s, ", str->c_str());
}
printf("\n");
}
printf("\nTotal ABs : %d\n", nAB);
}
以下是enumStr(4)的运行结果(2^3=8,初步检验通过):
AAA (1): abcd,
AAB (3): abdc, adbc, dabc,
ABA
(5): acbd, acdb, cabd, cadb, cdab,
ABB (3): adcb, dacb, dcab,
BAA (3): bacd, bcad, bcda,
BAB (5): badc, bdac, bdca, dbac, dbca,
BBA (3): cbad, cbda, cdba,
BBB (1): dcba,
Total ABs : 8
第三步 进阶分析
我们也可以根据前面第5点分析,做出一个更高效的蛮力算法,不过蛮力毕竟是蛮力,还是让我们动动脑筋,做个真正高效的算法吧。
我第一次拿到这个问题时,先用蛮力算法打印出前面的结果,试图分析其规律,没成功。便又在纸上演算了了一阵,发现其实可以递推解决(当然也可以递归解决),以下内容最好在纸上演算。比如对于序列”AAA”,字母d只可能在第3号位置出现一次(abcd);递推一下,对于序列”AAAB”,e在d前,那么e可以在第0、1、2、3号位置各出现一次(eabcd、aebcd、abecd、abced)。
又比如根据以前面第5点分析,如果我们知道对于序列”AB”,字母c可能在第0号位置出现一次(cab)、在第1号位置出现一次(acb);那么对于序列”ABA”,字母d会在第1、2、3号位置分别出现1、2、2次,因此”ABA”对应的字符串共有5个;同理对于序列”ABB”,字母d会在第0、1号位置分别出现2、1次,因此”ABB”对应的字符串共有3个。
继续递推,对于序列”ABBA”,e在d后,那么e可以在第1、2、3、4号位置分别出现2、3、3、3次(具体说来,对于d在第0号位置出现2次,那么e可以在第1、2、3、4号位置各出现2次;d在第1号位置出现1次,那么e可以在第2、3、4号位置各出现1次,对位加起来就得到前面“2、3、3、3”的结果),因此”ABBA”对应的字符串共有11个。
到这里,我们已经发现递推的规律了:对于AB序列w,用二维数组occurs[][]表示第letter个字母在位置pos出现的次数occurs[letter][pos](这个说法不太严格,应该说是w的前面长度为letter的子序列对应的字符串中,最大那个字母出现的位置和次数,呵呵,还是比较绕口)。如果字母p在位置q1出现n1次,而AB序列的当前元素为’A’,那么字母p+1会在位置q1+1, q1+2, . . . , p各出现n1次;如果AB序列的当前元素为’B’,那么字母p+1会在位置0, 1, . . . , q1各出现n1次;如果字母p还在q2位置出现了n2次,那么对于’A’ 情况,字母p+1还会在位置q2+1, q2+2, . . . , p各出现n2次;那么对于’B’ 情况,字母p+1还会在位置0, 1, . . . , q2各出现n2次。需要把这些情况都累加起来。
对于序列”ABBAA”,递推表如下:
1, 0, 0, 0, 0, 0 字母a在位置0出现1次
0, 1, 0, 0, 0, 0 字母b在位置1出现1次
1, 1, 0, 0, 0, 0 字母c在位置0、1分别出现1次
2, 1, 0, 0, 0, 0 字母d在位置0、1分别出现2、1次
0, 2, 3, 3, 3, 0 字母e在位置1、2、3、4分别出现2、3、3、3次
0, 0, 2, 5, 8, 11 字母f 在位置2、3、4、5分别出现2、5、8、11次
可知对应的字符串有26个。如果细心,已经能发现递推中的部分和(partial sum)关系。
第四步 解决
既然递推关系有了,很容易就能写出代码。这个算法的复杂度是O(N^3)。
// CODE 5
// the better way
// O(N^3)
int countABbetter(const string& AB)
{
assert(AB.find_first_not_of("AB") == string::npos);
int v = (int)AB.length();
int m = v + 1;
// 'letter' at 'pos' occurs 'occurs[letter][pos]' times.
vector<vector<int> > occurs(m, vector<int>(m, 0));
// letter 'a' at pos 0, 1 time
occurs[0][0] = 1;
for (int letter = 1; letter < m; ++letter) {
for (int pos = 0; pos < letter; ++pos) {
int first_pos = 0;
int last_pos = 0;
if (AB[letter-1] == 'A') {
// after current pos
first_pos = pos + 1;
last_pos = letter;
} else {
assert(AB[letter-1] == 'B');
// before (and at) current pos
first_pos = 0;
last_pos = pos;
}
int occur = occurs[letter-1][pos];
for (int t = first_pos; t <= last_pos; ++t) {
occurs[letter][t] += occur;
assert(occurs[letter][t] >= 0);
}
}
}
return accumulate(occurs[m-1].begin(), occurs[m-1].end(), 0);
}
第五步 优化
前面提过一句,在递推的过程中其实隐藏了一个“部分和”的关系,利用这一性质,可以很容易地将复杂度降为O(N^2),而且递推只是根据当前字母的出现位置退出下一字母的出现位置,因此可以省去2维数组,改用两个vector就行了。最后的代码就是前面一开始列出的 CODE 1。
第六步 展望
我猜测算法的复杂度能进一步降到 O(N log N),不过自己已经没有能力实现了。另外,为了附庸风雅一把,我发现整个递推算法的过程如果用矩阵来描述,会变得相当清楚。比如对于序列”ABAAB”,很容易构造矩阵A1、B2、A3、A4、B5(每个矩阵都是6阶方阵),初始向量x=[1 0 0 0 0 0]T,生成向量y=B5*A4*A3*B2*A1*x,那么对应的字符串有sum(y)个(sum表示y的各分量之和)。
注:也可以定义初始向量x=[1],矩阵A1是2x1、矩阵B2是3x2、矩阵A3是4x3、……、矩阵B5是6x5,一样可以计算出向量y。
例如:(这些矩阵中的元素都是0或1,排列起来像三角形(因为是求部分和),很有规律的。)
A1 = [0; 1]
B2 = [1 1; 0 1; 0 0]
A3 = [0 0 0; 1 0 0; 1 1 0; 1 1 1]
A4 = [0 0 0 0; 1 0 0 0; 1 1 0 0; 1 1 1 0; 1 1 1 1]
B5 = [1 1 1 1 1; 0 1 1 1 1; 0 0 1 1 1; 0 0 0 1 1; 0 0 0 0 1; 0 0 0 0 0]
算出y = B5*A4*A3*B2*A1 = [9 9 9 8 5 0] T
sum(y) = 40,与前面程序的结果相同。
. 完 .
Trackback: http://tb.blog.csdn.net/TrackBack.aspx?PostId=653418