A few weeks ago I stumbled on an interesting programming problem at CodeChef. The problem reminded me of dynamic programming, so I decided to try implementing a solution in Python.
Initially, the main part of the algorithm I implemented looked something like this:
    for i in range(1, x):
        for j in range(i - 1, y, -1):
            # a few lines of basic comparisons,
            # calculations, and assignments
            # ......
Just a simple nested loop. However, when I ran the code with the test input where x is 100000 and y is 1000, it took 60 seconds to complete! That is far from the problem's time limit of 2 seconds. Although this is a nested loop, the running time is O(xy), which should not take that long.
To my surprise, the output from cProfile showed that most of the time was spent in calls to the range function. After a few minutes of googling, the explanation became apparent. In Python 2, calling range creates a list of all the numbers in the specified range at once. In my case, the range call in the outer loop creates one million number objects right away, and the range call in the inner loop builds a fresh list on every iteration of the outer loop. That's why it was slow.
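For anyone who wants to reproduce this kind of measurement, here is a minimal sketch of profiling a function with cProfile. It is written for Python 3, and the function build_lists and the helper profile are hypothetical stand-ins rather than my original code; list(range(...)) is used to mimic the list-building behaviour of Python 2's range.

```python
import cProfile
import io
import pstats


def build_lists(n):
    """Hypothetical hot spot: builds a full list on every iteration."""
    total = 0
    for i in range(1, n):
        # list(range(...)) mimics Python 2's eager range()
        total += len(list(range(i - 1, -1, -1)))
    return total


def profile(func, *args):
    """Run func under cProfile and return (result, report text)."""
    profiler = cProfile.Profile()
    result = profiler.runcall(func, *args)
    buffer = io.StringIO()
    stats = pstats.Stats(profiler, stream=buffer)
    # Sort by cumulative time and keep only the top five entries
    stats.sort_stats("cumulative").print_stats(5)
    return result, buffer.getvalue()


result, report = profile(build_lists, 2000)
print(report)
```

For a whole script, running it as python -m cProfile followed by the script name gives a similar report without any code changes.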
The remedy? Use xrange instead of range. xrange is similar to range, but it returns a lazy sequence object that does not create all the numbers up front; each one is produced only when the loop asks for it.
Using xrange in the above code reduced the running time from 60 to 6 seconds right away. The moral of this story is to consider using xrange instead of range when the number is large. Note that this only applies to Python 2.x, though. In Python 3, range itself returns a lazy range object, and thus xrange was removed.
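The difference between the two behaviours can be seen with a small sketch. It is written for Python 3, so list(range(n)) stands in for Python 2's range and plain range(n) stands in for xrange; the names eager, lazy, and steps_until are my own choices for illustration.

```python
import sys

n = 1000000

eager = list(range(n))  # every number is materialized up front
lazy = range(n)         # numbers are produced on demand

# getsizeof reports the container itself, not the integers it refers to
eager_bytes = sys.getsizeof(eager)
lazy_bytes = sys.getsizeof(lazy)
print(eager_bytes, lazy_bytes)


def steps_until(values, limit):
    """Count iterations until a value reaches limit, then break early."""
    steps = 0
    for value in values:
        if value >= limit:
            break
        steps += 1
    return steps


# Both give the same answer, but only the eager version paid for
# a million-element list before the loop even started.
print(steps_until(eager, 10), steps_until(lazy, 10))
```

The early break is the important part: a loop that stops after a handful of iterations still pays for the whole list when the range is built eagerly.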
Continuing further, cProfile showed that the time-consuming section was now the calls to raw_input. Yes, I naively read a million lines of input by calling raw_input in a loop. After changing that part to use sys.stdin.readlines(), the running time decreased to 2.7 seconds. Here is the final version of my code:
    import sys

    # MAX_DIST is assumed to be a large sentinel constant defined earlier
    (num_stations, fuel) = [int(x) for x in raw_input().split()]

    # read station distances
    dist = [int(x) for x in sys.stdin.readlines()]

    stops = [MAX_DIST] * len(dist)
    num_sols = [0] * len(dist)
    stops[0] = 0
    num_sols[0] = 1

    for i in xrange(1, len(dist)):
        for j in xrange(i - 1, -1, -1):
            # backtrack until out of fuel range
            if dist[i] - dist[j] > fuel:
                break
            # assert dist[i] - dist[j] <= fuel
            cur_stop = 1 + stops[j]
            if stops[i] > cur_stop:
                # new min stops
                stops[i] = cur_stop
                num_sols[i] = num_sols[j]  # reset number of solutions
            elif cur_stop == stops[i]:
                # another optimal predecessor: add every route through j
                num_sols[i] += num_sols[j]

    print "%d %d" % (stops[-1] - 1, num_sols[-1])  # remove final stop
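The input-reading change can be tried in isolation. The following is a minimal Python 3 sketch that contrasts one read call per line with a single bulk read; the helper names read_per_line and read_bulk and the sample data are hypothetical, and io.StringIO stands in for sys.stdin so the snippet runs without piped input.

```python
import io


def read_per_line(stream, count):
    """Read count integers with one readline() call each."""
    return [int(stream.readline()) for _ in range(count)]


def read_bulk(stream):
    """Read all remaining integers with a single readlines() call."""
    return [int(line) for line in stream.readlines()]


# Made-up input in the shape the code above parses:
# a header line "num_stations fuel", then one distance per line.
sample = "4 2\n0\n1\n2\n3\n"

stream = io.StringIO(sample)
num_stations, fuel = [int(x) for x in stream.readline().split()]
dist = read_bulk(stream)
print(num_stations, fuel, dist)
```

In a real script, sys.stdin would be passed in place of the StringIO object. Both helpers return the same list; the bulk version simply makes one call instead of one per line.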
I believe there are still parts that can be optimized further, ideally to meet the 2-second time limit. If you have any ideas or spot anything wrong, please let me know in the comments.