  • cr.rusucosmin replied 9 years ago

    Hello everyone. First, I have to say this is a very nice problem. I managed to get an algorithm with complexity

    $O(N \log N \log(\mathrm{MAXVALUE}) + M \log N \log(\mathrm{MAXVALUE}))$

    using a segment tree with a trie in every node of the tree. Still, I'm getting WA and I can't figure out why. The judge reports a running time of 0.000 s, which is strange because the complexity is quite high in my opinion. So I suspect the problem is not really a wrong answer but (maybe) a memory problem, because I use a lot of memory to store the tries and the segment tree. I'm asking for help: maybe someone can give me a test that my solution fails, or tell me what the problem is. Here's my code:

    #include <cassert>
    #include <cctype>
    #include <cstdio>
    #include <string.h>
    #include <algorithm>
    #include <iostream>
    #include <queue>
    #include <vector>
    using namespace std;
    // Capacity of every per-node array (tree nodes are numbered from 1).
    const int maxn = 100005;
    // Highest bit index walked by the tries: bits bt, bt-1, ..., 0 of each value.
    const int bt = 31;

    int t;                      // number of test cases still to process
    int n;                      // number of tree nodes
    int m;                      // number of queries
    int h;                      // Euler-tour counter: last position handed out by dfs
    int a[maxn];                // a[u]: value stored at tree node u
    int first[maxn];            // first[u]: Euler position where u's subtree starts
    int last[maxn];             // last[u]: Euler position where u's subtree ends
    int euler[maxn];            // euler[p]: tree node placed at Euler position p
    std::vector<int> g[maxn];   // g[u]: neighbours of u (adjacency lists)
    // Binary-trie node. `cnt` is the number of stored values whose path passes
    // through this node (kept up to date by add); `sons[b]` is the child reached
    // by bit b, or null when absent. A node owns its children: deleting a node
    // frees the whole subtree beneath it.
    struct trie {
        int cnt;
        trie *sons[2];
        trie() : cnt(0) {
            // Explicit null pointers rather than memset's all-bits-zero.
            sons[0] = 0;
            sons[1] = 0;
        }
        ~trie() {
            // delete on a null child is a no-op, so no check is needed.
            for(int i = 0 ; i < 2 ; ++ i)
                delete sons[i];
        }
    private:
        // Not copyable: a copy would share the children and both copies would
        // later delete them. Declared but never defined (pre-C++11 idiom).
        trie(const trie &);
        trie &operator=(const trie &);
    } *arb[maxn << 2];   // trie root per segment-tree node (heap indexing: node, 2*node, 2*node+1)
    // Numbers the subtree hanging from `node` (entered from `father`; main
    // passes 0 for the root) in DFS pre-order, continuing from the global
    // counter h:
    //   first[u] - position assigned when u is entered,
    //   euler[p] - node placed at position p,
    //   last[u]  - largest position inside u's subtree,
    // so u's subtree occupies exactly positions [first[u], last[u]].
    // Children are visited in g[u] order, exactly as a recursive DFS would.
    // An explicit stack replaces recursion: on a path-shaped tree the recursion
    // depth equals n (~1e5), which may overflow the call stack.
    inline void dfs(int node, int father) {
        std::vector<int> stk_node(1, node);        // nodes on the current root path
        std::vector<int> stk_parent(1, father);    // parent of each of those nodes
        std::vector<std::size_t> stk_next(1, 0);   // next index into g[] to examine
        first[node] = ++ h;
        euler[h] = node;
        while(!stk_node.empty()) {
            int u = stk_node.back();
            if(stk_next.back() < g[u].size()) {
                int child = g[u][stk_next.back()++];
                if(child != stk_parent.back()) {
                    first[child] = ++ h;
                    euler[h] = child;
                    stk_node.push_back(child);
                    stk_parent.push_back(u);
                    stk_next.push_back(0);
                }
            } else {
                // Every neighbour handled: u's subtree ends at the current position.
                last[u] = h;
                stk_node.pop_back();
                stk_parent.pop_back();
                stk_next.pop_back();
            }
        }
    }
    // Buffered stdin reader. Input is pulled in chunks of up to `lim` bytes into
    // `buff`; `pos` is the index of the next unread byte and `len` the number of
    // valid bytes delivered by the last fread. Both start at 0, so the first
    // read refills at once. Tracking `len` guarantees that bytes left over from
    // an earlier, longer chunk are never parsed after a short final read.
    const int lim = (1 << 20);
    int pos;
    int len;
    char buff[lim];

    // Returns the next unread byte (as an unsigned char value) without
    // consuming it, refilling `buff` from stdin when it is exhausted.
    // Returns 0 once stdin has no more data.
    inline int peekbyte() {
        if(pos == len) {
            pos = 0;
            len = static_cast<int>(fread(buff, 1, lim, stdin));
            if(len == 0)
                return 0;
        }
        return static_cast<unsigned char>(buff[pos]);
    }

    // Reads the next unsigned decimal integer from stdin into `x`, skipping
    // every non-digit byte before it. A leading '-' is skipped as well, so
    // negative numbers are read as their absolute value. If the input ends
    // before any digit is found, `x` is left at 0.
    inline void getint(int &x) {
        x = 0;
        int c = peekbyte();
        while(c != 0 && !isdigit(c)) {
            ++ pos;
            c = peekbyte();
        }
        while(c != 0 && isdigit(c)) {
            x = x * 10 + (c - '0');
            ++ pos;
            c = peekbyte();
        }
    }
    // Inserts `value` with multiplicity `cnt` into the trie rooted at `node`.
    // It follows bits `bit` down to 0 (most significant first) and adds `cnt` to
    // the counter of every node on the path, so each node's cnt stays equal to
    // the number of values stored below it. This matches recomputing a node's
    // cnt from its children, because every count on the path changes by the
    // same `cnt`. Missing nodes on the path are allocated, including the root
    // itself when `node` is null.
    inline void add(trie *&node, int value, int cnt, int bit) {
        if(!node)
            node = new trie();
        trie *cur = node;
        for(; bit >= 0 ; -- bit) {
            cur->cnt += cnt;
            // (value >> bit) & 1 never forms 1 << 31, which does not fit in an int.
            int son = (value >> bit) & 1;
            if(!cur->sons[son])
                cur->sons[son] = new trie();
            cur = cur->sons[son];
        }
        // Leaf (one level below bit 0): multiplicity of this exact value.
        cur->cnt += cnt;
    }
    // Builds the segment tree over Euler positions [st, dr]. Each segment-tree
    // node (heap-indexed: children 2*node and 2*node+1) gets a fresh trie that
    // holds the values a[euler[i]] for every i in its segment.
    // NOTE(review): every level stores each value once more, so the worst case
    // is O(n log n * (bt + 2)) trie nodes: tens of millions for n = 1e5. That
    // may exceed the memory limit. A persistent trie over Euler-order prefixes
    // needs only O(n * (bt + 2)) nodes.
    inline void build(int node, int st, int dr) {
        if(st == dr) {
            // Free the trie left here by a previous test case (null-safe).
            delete arb[node];
            arb[node] = new trie();
            add(arb[node], a[euler[st]], 1, bt);
            return ;
        }
        int mid = ((st + dr) >> 1);
        build(node << 1, st, mid);
        build((node << 1) | 1, mid + 1, dr);
        delete arb[node];
        arb[node] = new trie();
        for(int i = st ; i <= dr ; ++ i)
            add(arb[node], a[euler[i]], 1, bt);
    }
    vector <trie *> v, nxt[2];
    inline void query(int node, int st, int dr, int x, int y) {
        if(x <= st && dr <= y) {
            return ;
        int mid = ((st + dr) >> 1);
        if(x <= mid)
            query(node << 1, st, mid, x, y);
        if(mid < y)
            query((node << 1) | 1, mid + 1, dr, x, y);
    inline int solve(int k) {
        int ret = 0;
        for(int bit = bt ; bit >= 0 ; -- bit) {
            int sum = 0;
            for(vector <trie *> :: iterator it = v.begin() ; it != v.end() ; ++ it) {
                if(*it && (*it)->sons[0])
                    sum += (*it)->sons[0]->cnt;
                if(*it && (*it)->sons[0])
                if(*it && (*it)->sons[1])
            bool son = 0;
            if(k > sum) {
                k -= sum;
                son = 1;
            ret = ret * 2 + son;
            v = nxt[son];
        assert(ret <= 1000000000);
        return ret;
    // Entry point as posted. NOTE(review): this paste is incomplete. Not visible:
    // the statements reading t, n, m, the values a[i], the edges (and filling
    // or clearing g), each query's (x, k), and the closing braces.
    int main() {
        // NOTE(review): unless the judge defines ONLINE_JUDGE, these freopen
        // calls also run on the judge. An empty file name cannot be opened, and
        // a failed freopen has already closed stdin. stdout is redirected to a
        // file the judge does not read. Either way the judge would see empty
        // output, which would explain a WA in 0.000 s. The matching #endif is
        // not visible in this paste.
        #ifndef ONLINE_JUDGE
        freopen("", "r", stdin);
        freopen("uri1695.out", "w", stdout);
        while (t--) {
            // Restart the Euler-tour counter used by dfs for this test case.
            h = 0;
            for(int i = 1 ; i <= n ; ++ i) {
            for(int i = 1 ; i < n ; ++ i) {
                int x, y;
            // Number the nodes in DFS order from root 1, then build the
            // segment tree of tries over that order.
            dfs(1, 0);
            build(1, 1, n);
            for(int i = 1 ; i <= m ; ++ i) {
                int x, k;
                // The subtree of x is Euler range [first[x], last[x]].
                // NOTE(review): the trie list filled by query must be emptied
                // before every query; confirm the missing lines do that.
                query(1, 1, n, first[x], last[x]);
                // NOTE(review): answers are separated by a space with no
                // trailing newline; confirm this matches the expected output.
                printf("%d ", solve(k));

    I'm sorry if I've wasted your time, but if you know what the problem is, please tell me; I would really appreciate it. P.S. I don't know whether it's fair play, but if URI allows it, please send me the test data so I can debug and test my solution. That would be perfect :).