Replace Python range() with xrange() for speed
A few weeks ago I stumbled on an interesting programming problem at CodeChef. The problem reminded me of dynamic programming, so I decided to try implementing a solution in Python.
Initially, the main part of the algorithm I implemented looked something like this:
for i in range(1, x):
    for j in range(i - 1, y, -1):
        # a few lines of basic comparisons,
        # calculations, and assignments
        # ......
Just a simple nested loop. However, when I ran the code with the test input where x is 100000 and y is 1000, it took 60 seconds to complete, far from the problem's time limit of 2 seconds. Although this is a nested loop, the running time is O(xy), which should not take that long.
To my surprise, the result from cProfile showed that most of the time was spent on calls to the range function. After a few minutes of googling for an explanation, the cause became apparent. In Python 2, calling range creates a list of all the numbers in the specified range at once. In my case, the range call in the outer loop creates one million number objects right away, and the range call in the inner loop builds a fresh list on every iteration of the outer loop. That's why it was slow.
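A measurement like this can be reproduced with a small profiling harness. The following is only a sketch, written for Python 3: `build_and_scan` and `profile_call` are hypothetical names, the sizes are made up, and `list(range(...))` is used to imitate the eager list that Python 2's range builds.

```python
import cProfile
import io
import pstats


def build_and_scan(x, window):
    """Hypothetical stand-in for the nested loop: builds a full list per outer step."""
    total = 0
    for i in range(1, x):
        # list(...) forces eager construction, imitating Python 2's range()
        for j in list(range(i - 1, -1, -1)):
            if i - j > window:
                break
            total += 1
    return total


def profile_call(func, *args):
    """Run func under cProfile and return (result, report text)."""
    profiler = cProfile.Profile()
    result = profiler.runcall(func, *args)
    out = io.StringIO()
    stats = pstats.Stats(profiler, stream=out)
    stats.sort_stats("cumulative").print_stats(5)  # show the top 5 entries
    return result, out.getvalue()


if __name__ == "__main__":
    result, report = profile_call(build_and_scan, 2000, 10)
    print(report)
```

Sorting by cumulative time puts the functions that dominate the run at the top of the report, which is usually the quickest way to see where a loop is really spending its time.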
The remedy? Use xrange instead of range. xrange is similar to range, but it returns a lazy object that, like a generator, does not create the numbers up front; each one is produced only when the loop asks for it.
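The difference between the eager and the lazy form can be seen directly by comparing object sizes. This is a minimal sketch, written for Python 3, where `list(range(N))` plays the role of Python 2's range and plain `range(N)` plays the role of xrange. The value of `N` and the variable names are arbitrary.

```python
import sys

N = 1000000

eager = list(range(N))  # every integer is created and stored up front
lazy = range(N)         # values are computed on demand, like Python 2's xrange

# getsizeof reports the container itself (for the list: its array of pointers)
eager_bytes = sys.getsizeof(eager)
lazy_bytes = sys.getsizeof(lazy)

print("eager list: %d bytes" % eager_bytes)
print("lazy range: %d bytes" % lazy_bytes)

# An early break only ever touches the first few values of the lazy range,
# whereas the eager list was fully built before the loop could start.
steps = 0
for j in lazy:
    if j >= 3:
        break
    steps += 1
```

The lazy object stays small no matter how large `N` is, which matters most when the loop breaks early and most of an eagerly built list would never be used.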
Changing range to xrange in the above code reduced the running time from 60 to 6 seconds right away. The moral of this story is to consider using xrange instead of range when the number is large. Note that this only applies to Python 2.x, though. In Python 3, range itself is lazy and produces its values on demand, so xrange was removed.
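For code that has to run on both Python 2 and Python 3, one common approach is a small alias that picks the lazy variant on either version. The sketch below assumes nothing beyond the built-ins; `lazy_range` and `first_multiple` are hypothetical names used for illustration.

```python
try:
    lazy_range = xrange  # Python 2: the lazy variant
except NameError:
    lazy_range = range   # Python 3: range is already lazy


def first_multiple(n, limit):
    """Return the first multiple of n in [1, limit), or None if there is none."""
    for i in lazy_range(1, limit):
        if i % n == 0:
            return i
    return None


# The huge upper bound costs nothing here, because no list is ever built.
print(first_multiple(7, 10 ** 9))
```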
Continuing further, cProfile showed that the time-consuming section was now the calls to raw_input. Yes, I naively read a million lines of input by calling raw_input in a loop. After changing that part to use sys.stdin.readlines(), the running time decreased to 2.7 seconds. Here is the final version of my code:
import sys

# Sentinel for "not reachable yet". The original definition of MAX_DIST
# is not shown, so sys.maxint is an assumption here.
MAX_DIST = sys.maxint

(num_stations, fuel) = [int(x) for x in raw_input().split()]
# read station distances
dist = [int(x) for x in sys.stdin.readlines()]
stops = [MAX_DIST] * len(dist)
num_sols = [0] * len(dist)
stops[0] = 0
num_sols[0] = 1
for i in xrange(1, len(dist)):
    for j in xrange(i - 1, -1, -1):
        # backtrack until out of fuel range
        if dist[i] - dist[j] > fuel:
            break
        # assert dist[i] - dist[j] <= fuel
        cur_stop = 1 + stops[j]
        if stops[i] > cur_stop:  # new min stops
            stops[i] = cur_stop
            num_sols[i] = num_sols[j]  # reset number of solutions
        elif cur_stop == stops[i]:  # another optimal solution
            num_sols[i] += num_sols[j]  # every optimal route to j extends to i
# report the last station; subtract 1 to remove the final stop
print "%d %d" % (stops[-1] - 1, num_sols[-1])
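The input-reading change can be isolated into a small function and tried without real standard input. This is a sketch under assumptions: the input layout (a header line with two integers, then one distance per line) is inferred from the code above, `read_distances` is a hypothetical helper, and the sample values are made up.

```python
import io


def read_distances(stream):
    """Parse a header line 'num_stations fuel' followed by one distance per line."""
    num_stations, fuel = [int(tok) for tok in stream.readline().split()]
    # readlines() pulls the rest of the input in one call instead of one call per line
    dist = [int(line) for line in stream.readlines() if line.strip()]
    return num_stations, fuel, dist


# Simulated input; in the real program the stream would be sys.stdin.
sample = io.StringIO(u"3 10\n0\n5\n12\n")
num_stations, fuel, dist = read_distances(sample)
print(num_stations, fuel, dist)
```

Passing the stream in as a parameter keeps the parsing testable, and swapping in `sys.stdin` requires no other change.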
I believe there are still parts that can be optimized further, ideally to meet the 2-second time limit. If you have any ideas or spot anything wrong, please let me know in the comments.