Skip to content

Commit

Permalink
ARROW-945: [GLib] Add a Lua example to show Torch integration
Browse files Browse the repository at this point in the history
Author: Kouhei Sutou <[email protected]>

Closes #637 from kou/glib-lua-to-torch-tensor and squashes the following commits:

4aba395 [Kouhei Sutou] [GLib] Add a Lua example to show Torch integration
  • Loading branch information
kou authored and wesm committed May 5, 2017
1 parent 80b72d4 commit bcf073c
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 1 deletion.
1 change: 1 addition & 0 deletions c_glib/example/lua/Makefile.am
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,6 @@ dist_lua_example_DATA = \
README.md \
read-batch.lua \
read-stream.lua \
stream-to-torch-tensor.lua \
write-batch.lua \
write-stream.lua
5 changes: 5 additions & 0 deletions c_glib/example/lua/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,8 @@ Here are example codes in this directory:

* `read-stream.lua`: It shows how to read Arrow arrays from a file in
stream mode.

* `stream-to-torch-tensor.lua`: It shows how to read Arrow arrays
from a file in stream mode and convert them to
[Torch](http://torch.ch/)'s
[`Tensor` objects](http://torch7.readthedocs.io/en/rtd/tensor/index.html).
2 changes: 1 addition & 1 deletion c_glib/example/lua/read-stream.lua
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ local reader = Arrow.StreamReader.open(input)

local i = 0
while true do
local record_batch = reader:get_next_record_batch(i)
local record_batch = reader:get_next_record_batch()
if not record_batch then
break
end
Expand Down
101 changes: 101 additions & 0 deletions c_glib/example/lua/stream-to-torch-tensor.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
-- Licensed to the Apache Software Foundation (ASF) under one
-- or more contributor license agreements. See the NOTICE file
-- distributed with this work for additional information
-- regarding copyright ownership. The ASF licenses this file
-- to you under the Apache License, Version 2.0 (the
-- "License"); you may not use this file except in compliance
-- with the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing,
-- software distributed under the License is distributed on an
-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-- KIND, either express or implied. See the License for the
-- specific language governing permissions and limitations
-- under the License.

local lgi = require 'lgi'
local Arrow = lgi.Arrow

local torch = require 'torch'

-- Base definition on Arrow.Array: reports that no Torch storage/tensor
-- types are available. to_torch() returns nil when it receives nil from
-- this method.
function Arrow.Array:torch_types()
  return nil
end

-- Copy the values of this Arrow array into a newly created Torch tensor.
--
-- The storage and tensor constructors come from self:torch_types(), which
-- is expected to return {storage constructor, tensor constructor}.
--
-- Returns the tensor built on top of the filled storage, or nil when
-- torch_types() yields nothing or the storage constructor returns a
-- falsy value.
function Arrow.Array:to_torch()
  local type_pair = self:torch_types()
  if not type_pair then
    return nil
  end

  local new_storage, new_tensor = type_pair[1], type_pair[2]

  local length = self:get_length()
  local data = new_storage(length)
  if not data then
    return nil
  end

  -- get_value() is called with 0-based offsets while the storage is
  -- written at 1-based positions, hence the shift by one.
  -- NOTE(review): null (validity) information is not consulted here --
  -- confirm what get_value() returns for null slots.
  for offset = 0, length - 1 do
    data[offset + 1] = self:get_value(offset)
  end
  return new_tensor(data)
end

-- Arrow.UInt8Array converts through Torch's ByteStorage / ByteTensor.
-- Returns {storage constructor, tensor constructor}, the pair that
-- to_torch() reads as types[1] and types[2].
function Arrow.UInt8Array:torch_types()
  local storage_type, tensor_type = torch.ByteStorage, torch.ByteTensor
  return {storage_type, tensor_type}
end

-- Arrow.Int8Array converts through Torch's CharStorage / CharTensor.
-- Returns {storage constructor, tensor constructor}.
function Arrow.Int8Array:torch_types()
  local storage_type, tensor_type = torch.CharStorage, torch.CharTensor
  return {storage_type, tensor_type}
end

-- Arrow.Int16Array converts through Torch's ShortStorage / ShortTensor.
-- Returns {storage constructor, tensor constructor}.
function Arrow.Int16Array:torch_types()
  local storage_type, tensor_type = torch.ShortStorage, torch.ShortTensor
  return {storage_type, tensor_type}
end

-- Arrow.Int32Array converts through Torch's IntStorage / IntTensor.
-- Returns {storage constructor, tensor constructor}.
function Arrow.Int32Array:torch_types()
  local storage_type, tensor_type = torch.IntStorage, torch.IntTensor
  return {storage_type, tensor_type}
end

-- Arrow.Int64Array converts through Torch's LongStorage / LongTensor.
-- Returns {storage constructor, tensor constructor}.
function Arrow.Int64Array:torch_types()
  local storage_type, tensor_type = torch.LongStorage, torch.LongTensor
  return {storage_type, tensor_type}
end

-- Arrow.FloatArray converts through Torch's FloatStorage / FloatTensor.
-- Returns {storage constructor, tensor constructor}.
function Arrow.FloatArray:torch_types()
  local storage_type, tensor_type = torch.FloatStorage, torch.FloatTensor
  return {storage_type, tensor_type}
end

-- Arrow.DoubleArray converts through Torch's DoubleStorage / DoubleTensor.
-- Returns {storage constructor, tensor constructor}.
function Arrow.DoubleArray:torch_types()
  local storage_type, tensor_type = torch.DoubleStorage, torch.DoubleTensor
  return {storage_type, tensor_type}
end


-- Script entry point: open an Arrow stream file (first command-line
-- argument, falling back to /tmp/stream.arrow), walk over its record
-- batches and print every column after converting it with to_torch().
local stream_path = arg[1] or "/tmp/stream.arrow"

local stream = Arrow.MemoryMappedInputStream.new(stream_path)
local stream_reader = Arrow.StreamReader.open(stream)

local separator = string.rep("=", 40)
local batch_index = 0
local batch = stream_reader:get_next_record_batch()
-- The loop ends as soon as the reader hands back a falsy value instead
-- of a record batch.
while batch do
  print(separator)
  print("record-batch["..batch_index.."]:")

  local last_column = batch:get_n_columns() - 1
  for column_index = 0, last_column do
    local values = batch:get_column(column_index)
    local name = batch:get_column_name(column_index)
    print(" "..name..":")
    -- to_torch() returns nil when no Torch mapping is available, so
    -- this may print "nil" for such columns.
    print(values:to_torch())
  end

  batch_index = batch_index + 1
  batch = stream_reader:get_next_record_batch()
end

stream:close()

0 comments on commit bcf073c

Please sign in to comment.