- 论坛徽章:
- 0
|
- for k in range(num_epochs):
- # 给数据洗个牌
- index_shuf = [i for i in range(len(questions_train))]
- shuffle(index_shuf)
- # 一一取出 问题,答案,和图片
- questions_train = [questions_train[i] for i in index_shuf]
- answers_train = [answers_train[i] for i in index_shuf]
- images_train = [images_train[i] for i in index_shuf]
- # batch分组
- for qu_batch,an_batch,im_batch in zip(grouper(questions_train, batch_size, fillvalue=questions_train[-1]),
- grouper(answers_train, batch_size, fillvalue=answers_train[-1]),
- grouper(images_train, batch_size, fillvalue=images_train[-1])):
- X_q_batch = get_questions_matrix_sum(qu_batch, nlp)
- X_i_batch = get_images_matrix(im_batch, id_map, VGGfeatures)
- X_batch = np.hstack((X_q_batch, X_i_batch))
- Y_batch = get_answers_matrix(an_batch, labelencoder)
- loss = model.train_on_batch(X_batch, Y_batch)
- progbar.add(batch_size, values=[("train loss", loss)])
- # 并且告诉模型,隔多久,存一次模型,比如这里,model_save_interval是10
- if k%model_save_interval == 0:
- model.save_weights(model_file_name + '_epoch_{:02d}.hdf5'.format(k))
- # 把最终的模型也存下来
- model.save_weights(model_file_name + '_epoch_{:02d}.hdf5'.format(k))
复制代码- ---------------------------------------------------------------------------
- TypeError Traceback (most recent call last)
- <ipython-input-16-f74b00680e52> in <module>()
- 15 grouper(answers_train, batch_size, fillvalue=answers_train[-1]),
- 16 grouper(images_train, batch_size, fillvalue=images_train[-1])):
- ---> 17 X_q_batch = get_questions_matrix_sum(qu_batch, nlp)
- 18 X_i_batch = get_images_matrix(im_batch, id_map, VGGfeatures)
- 19 X_batch = np.hstack((X_q_batch, X_i_batch))
- <ipython-input-9-175d4cdea086> in get_questions_matrix_sum(questions, nlp)
- 11 # assert not isinstance(questions, basestring)
- 12 nb_samples = len(questions)
- ---> 13 word_vec_dim = nlp(questions[0])[0].vector.shape[0]
- 14 questions_matrix = np.zeros((nb_samples, word_vec_dim))
- 15 for i in range(len(questions)):
- /anaconda2/lib/python2.7/site-packages/spacy/language.pyc in __call__(self, text, disable)
- 338 raise ValueError(Errors.E088.format(length=len(text),
- 339 max_length=self.max_length))
- --> 340 doc = self.make_doc(text)
- 341 for name, proc in self.pipeline:
- 342 if name in disable:
- /anaconda2/lib/python2.7/site-packages/spacy/language.pyc in make_doc(self, text)
- 370
- 371 def make_doc(self, text):
- --> 372 return self.tokenizer(text)
- 373
- 374 def update(self, docs, golds, drop=0., sgd=None, losses=None):
- TypeError: Argument 'string' has incorrect type (expected unicode, got str)
复制代码 环境:anaconda python:2.7
这个代码之前在python2的环境下是可以正常运行的,求助。
TypeError: Argument 'string' has incorrect type (expected unicode, got str)
|
|