Vairoj writes

Replace Python range() with xrange() for speed

A few weeks ago I stumbled on an interesting programming problem at CodeChef. It reminded me of dynamic programming, so I decided to try implementing a solution in Python.

Initially, the main part of the algorithm I implemented looked something like this:

for i in range(1, x):
  for j in range(i - 1, y, -1):
    # a few lines of basic comparisons,
    # calculations, and assignments
    # ......

Just a simple nested loop. However, when I ran the code on the test input where x is 100000 and y is 1000, it took 60 seconds to complete, far from the problem's time limit of 2 seconds. Although this is a nested loop, its running time is O(xy), which should not take that long.

To my surprise, the cProfile results showed that most of the time was spent in calls to the range function. A few minutes of googling made the reason apparent: in Python 2, range builds a list containing every number in the requested range, all at once. In my case, the outer range call creates a list of x − 1 integer objects right away, and the inner range call builds a brand-new list on every iteration of the outer loop. And that's why it was slow.
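To make the cost of an eager list concrete, here is a small sketch in Python 3 (where list(range(n)) reproduces what Python 2's range(n) returned, and range(n) behaves like Python 2's xrange). It compares the memory reported by sys.getsizeof for a million-element list against a lazy range object.

```python
import sys

n = 10 ** 6

# Equivalent of Python 2's range(n): every integer is materialized up front.
eager = list(range(n))

# Equivalent of Python 2's xrange(n): only start, stop and step are stored.
lazy = range(n)

# getsizeof on the list counts its internal pointer array only,
# not the integer objects it points to, so the real cost is even higher.
eager_bytes = sys.getsizeof(eager)
lazy_bytes = sys.getsizeof(lazy)

print(eager_bytes, lazy_bytes)
```

The lazy object's size does not depend on n at all, while the list grows linearly with it, and in the nested loop above that list is rebuilt once per outer iteration.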

The remedy? Use xrange instead of range. xrange accepts the same arguments as range, but instead of building a list it returns a lightweight xrange object that produces each number only when the loop asks for the next one.
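The "produce each number on demand" idea can be illustrated with a plain Python generator. This is only a sketch: lazy_range is a hypothetical helper, and the real xrange is a sequence type (it also supports len() and indexing), not a generator, but the way values are handed out one at a time is the same.

```python
def lazy_range(start, stop, step=1):
    """Yield start, start + step, ... up to (not including) stop, one value at a time.

    Hypothetical helper for illustration; it never builds a list.
    """
    if step == 0:
        raise ValueError("step must not be zero")
    current = start
    if step > 0:
        while current < stop:
            yield current
            current += step
    else:
        while current > stop:
            yield current
            current += step


# Count down from 4 to 0, like xrange(i - 1, -1, -1) with i = 5.
countdown = list(lazy_range(4, -1, -1))
print(countdown)  # [4, 3, 2, 1, 0]

# Even an enormous range costs nothing until values are pulled from it.
huge = lazy_range(0, 10 ** 12)
first = next(huge)
```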

Changing range to xrange in the above code immediately reduced the running time from 60 to 6 seconds. The moral of this story: consider using xrange instead of range when the range is large. Note that this applies only to Python 2.x. In Python 3, range itself returns a lazy range object (essentially Python 2's xrange), so xrange was removed.
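If the same script has to run under both Python 2 and Python 3, a common idiom is to fall back to range when xrange does not exist. A minimal sketch; the summing loop is just an arbitrary example workload:

```python
try:
    xrange  # Python 2: the lazy built-in already exists
except NameError:
    xrange = range  # Python 3: range is already lazy, so reuse it under the old name

# On either version this loop never builds the full list of numbers.
total = 0
for i in xrange(1, 100001):
    total += i
print(total)  # 5000050000
```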

Profiling again, cProfile showed that the most time-consuming part was now the calls to raw_input. Yes, I had naively read a million lines of input by calling raw_input in a loop. After switching that part to sys.stdin.readlines(), the running time dropped to 2.7 seconds. Here is the final version of my code:
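The input-reading change can be isolated into a small function. This is a Python 3 sketch under assumptions: read_input is a hypothetical name, the input layout (a "num_stations fuel" line followed by one distance per line) is inferred from the code in this post, and io.StringIO stands in for sys.stdin so the example runs on its own.

```python
import io


def read_input(stream):
    """Parse 'num_stations fuel' on the first line, then one distance per line.

    Hypothetical helper; pass sys.stdin on a real judge.
    """
    num_stations, fuel = [int(x) for x in stream.readline().split()]
    # A single readlines() call replaces one raw_input()/input() call per line.
    # Blank lines (e.g. a trailing newline) are skipped so int() does not fail.
    dist = [int(line) for line in stream.readlines() if line.strip()]
    return num_stations, fuel, dist


sample = io.StringIO("3 5\n0\n4\n9\n")
num_stations, fuel, dist = read_input(sample)
print(num_stations, fuel, dist)  # 3 5 [0, 4, 9]
```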

import sys

# read the number of stations and the fuel range
(num_stations, fuel) = [int(x) for x in raw_input().split()]
# read station distances, all remaining lines in one call
dist = [int(x) for x in sys.stdin.readlines()]

MAX_DIST = sys.maxint # assumed sentinel: larger than any possible number of stops
stops = [MAX_DIST] * len(dist)
num_sols = [0] * len(dist)
stops[0] = 0
num_sols[0] = 1
for i in xrange(1, len(dist)):
  for j in xrange(i - 1, -1, -1):
    # backtrack until out of fuel range
    if dist[i] - dist[j] > fuel: break
    # assert dist[i] - dist[j] <= fuel
    cur_stop = 1 + stops[j]
    if stops[i] > cur_stop: # new min stops
      stops[i] = cur_stop
      num_sols[i] = num_sols[j] # reset number of solutions
    elif cur_stop == stops[i]: # another optimal solution
      num_sols[i] += num_sols[j] # add every way of reaching station j
print "%d %d" % (stops[-1] - 1, num_sols[-1]) # remove final stop
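For testing the DP without a judge, the same algorithm can be wrapped in a function. This is a Python 3 sketch (range is already lazy there): count_min_stops is a hypothetical name, and it assumes dist is sorted, dist[0] is the start, dist[-1] is the destination, and arriving at the destination is not counted as a stop. On a tie it adds num_sols[j], so every way of reaching station j is counted, and it returns None when the destination is unreachable.

```python
def count_min_stops(dist, fuel):
    """Return (minimum intermediate stops, number of optimal ways), or None."""
    n = len(dist)
    unreachable = float("inf")
    stops = [unreachable] * n
    num_sols = [0] * n
    stops[0] = 0
    num_sols[0] = 1
    for i in range(1, n):
        for j in range(i - 1, -1, -1):
            # dist is sorted, so once j is out of range, all earlier j are too
            if dist[i] - dist[j] > fuel:
                break
            cur_stop = 1 + stops[j]
            if cur_stop < stops[i]:  # new minimum: restart the count
                stops[i] = cur_stop
                num_sols[i] = num_sols[j]
            elif cur_stop == stops[i]:  # tie: add the ways of reaching j
                num_sols[i] += num_sols[j]
    if stops[-1] == unreachable:
        return None
    return stops[-1] - 1, num_sols[-1]


# Two intermediate stops, reachable as (2, 5), (3, 5) or (3, 6).
print(count_min_stops([0, 2, 3, 5, 6, 8], 3))  # (2, 3)
```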

I believe some parts can be optimized further, ideally enough to meet the 2-second time limit. If you have any ideas or spot anything wrong, please let me know in the comments.
